def __init__(self, opt):
    """Build the data pipeline, backbone, optimizer and loss.

    Creates two batch loaders over ``ImageNet`` (shuffled ``train`` and
    deterministic ``val``) and three episodic loaders over ``MetaImageNet``
    (``train_phase_test``, ``val`` and ``test``). It then builds a ResNet-12
    with ``self.n_cls`` outputs, an SGD optimizer and a cross-entropy
    criterion. The model and criterion are moved to the GPU.

    Args:
        opt: parsed option namespace. It is forwarded to every dataset, and
            its ``test_batch_size`` sets the batch size of the episodic
            loaders.
    """
    self.opt = opt
    self.n_cls = 64
    self.train_trans, self.test_trans = self.transforms_options()

    # Supervised (non-episodic) loaders.
    train_data = ImageNet(args=self.opt, partition='train', transform=self.train_trans)
    self.train_loader = DataLoader(
        train_data,
        batch_size=Config.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=Config.num_workers,
    )
    val_data = ImageNet(args=self.opt, partition='val', transform=self.test_trans)
    self.val_loader = DataLoader(
        val_data,
        batch_size=Config.batch_size // 2,
        shuffle=False,
        drop_last=False,
        num_workers=Config.num_workers // 2,
    )

    def episodic_loader(split):
        # All episodic loaders share transforms, batch size and worker count.
        episodes = MetaImageNet(
            args=self.opt,
            partition=split,
            train_transform=self.train_trans,
            test_transform=self.test_trans,
            fix_seed=False,
        )
        return DataLoader(
            episodes,
            batch_size=self.opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=Config.num_workers,
        )

    self.meta_trainloader = episodic_loader('train_phase_test')
    self.meta_valloader = episodic_loader('val')
    self.meta_testloader = episodic_loader('test')

    # model
    backbone = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=5,
                        num_classes=self.n_cls)
    self.model = backbone.cuda()
    self.optimizer = optim.SGD(
        self.model.parameters(),
        lr=Config.learning_rate,
        momentum=0.9,
        weight_decay=5e-4,
    )
    self.criterion = nn.CrossEntropyLoss().cuda()
def main():
    """Supervised pre-training, optionally with an OPL/POPL auxiliary loss.

    Steps:
        1. Parse command-line options and write them to
           ``<tb_folder>/config.json``.
        2. Build the train/val loaders for ``opt.dataset``.
        3. Train ``opt.model`` for ``opt.epochs`` epochs.

    Train and validation metrics are logged to TensorBoard every epoch. A
    checkpoint is written every ``opt.save_freq`` epochs. The final weights,
    together with ``opt``, are saved as ``<save_folder>/<model>_last.pth``.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not supported.
    """
    opt = parse_option()
    # Persist the exact run configuration next to the TensorBoard logs.
    with open(os.path.join(opt.tb_folder, "config.json"), "w", encoding="utf-8") as fo:
        fo.write(json.dumps(vars(opt), indent=4))

    # Shared DataLoader settings: shuffled, full batches for training and
    # deterministic half-size batches for validation.
    def train_loader_for(dataset):
        return DataLoader(dataset, batch_size=opt.batch_size, shuffle=True,
                          drop_last=True, num_workers=opt.num_workers)

    def val_loader_for(dataset):
        return DataLoader(dataset, batch_size=opt.batch_size // 2, shuffle=False,
                          drop_last=False, num_workers=opt.num_workers // 2)

    # dataloader. Episodic (meta) loaders are not built because nothing in
    # this routine consumes them.
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = train_loader_for(
            ImageNet(args=opt, partition=train_partition, transform=train_trans))
        val_loader = val_loader_for(
            ImageNet(args=opt, partition='val', transform=test_trans))
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = train_loader_for(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans))
        val_loader = val_loader_for(
            TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans))
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        train_trans, test_trans = transforms_options['D']
        train_loader = train_loader_for(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans))
        val_loader = val_loader_for(
            CIFAR100(args=opt, partition='train', transform=test_trans))
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64 if opt.dataset == 'CIFAR-FS' else 60
    elif opt.dataset == "imagenet":
        train_trans, test_trans = transforms_options["A"]
        train_dataset = ImagenetFolder(root=os.path.join(opt.data_root, "train"),
                                       transform=train_trans)
        val_dataset = ImagenetFolder(root=os.path.join(opt.data_root, "val"),
                                     transform=test_trans)
        train_loader = train_loader_for(train_dataset)
        val_loader = val_loader_for(val_dataset)
        n_cls = 1000
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model = create_model(opt.model, n_cls, opt.dataset, use_srl=opt.srl)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(), lr=opt.learning_rate,
                              momentum=opt.momentum, weight_decay=opt.weight_decay)

    # classification loss
    if opt.label_smoothing:
        criterion = LabelSmoothing(smoothing=opt.smoothing_ratio)
    elif opt.gce:
        criterion = GuidedComplementEntropy(alpha=opt.gce_alpha, classes=n_cls)
    else:
        criterion = nn.CrossEntropyLoss()

    # optional auxiliary loss; when set, train/validate also return its terms
    if opt.opl:
        auxiliary_loss = OrthogonalProjectionLoss(use_attention=True)
    elif opt.popl:
        auxiliary_loss = PerpetualOrthogonalProjectionLoss(feat_dim=640)
    else:
        auxiliary_loss = None

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        if auxiliary_loss is not None:
            auxiliary_loss = auxiliary_loss.cuda()
        cudnn.benchmark = True

    # Checkpoints store the unwrapped network. Test the wrapper type rather
    # than opt.n_gpu: without CUDA the model is never wrapped, so `.module`
    # would not exist.
    bare_model = model.module if isinstance(model, nn.DataParallel) else model

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)
    else:
        scheduler = None

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):
        if not opt.cosine:
            adjust_learning_rate(epoch, opt, optimizer)

        print("==> training...")
        time1 = time.time()
        if auxiliary_loss is not None:
            train_acc, train_loss, [train_cel, train_opl] = train(
                epoch=epoch, train_loader=train_loader, model=model,
                criterion=criterion, optimizer=optimizer, opt=opt,
                auxiliary=auxiliary_loss)
        else:
            train_acc, train_loss = train(
                epoch=epoch, train_loader=train_loader, model=model,
                criterion=criterion, optimizer=optimizer, opt=opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # Since PyTorch 1.1 the scheduler must step *after* the epoch's
        # optimizer updates. Stepping first skips the initial learning rate.
        if scheduler is not None:
            scheduler.step()

        logger.log_value('accuracy/train_acc', train_acc, epoch)
        logger.log_value('train_losses/loss', train_loss, epoch)
        if auxiliary_loss is not None:
            logger.log_value('train_losses/cel', train_cel, epoch)
            logger.log_value('train_losses/opl', train_opl, epoch)
        else:
            logger.log_value('train_losses/cel', train_loss, epoch)

        if auxiliary_loss is not None:
            test_acc, test_acc_top5, test_loss, [test_cel, test_opl] = validate(
                val_loader, model, criterion, opt, auxiliary=auxiliary_loss)
        else:
            test_acc, test_acc_top5, test_loss = validate(
                val_loader, model, criterion, opt)

        logger.log_value('accuracy/test_acc', test_acc, epoch)
        logger.log_value('accuracy/test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_losses/loss', test_loss, epoch)
        if auxiliary_loss is not None:
            logger.log_value('test_losses/cel', test_cel, epoch)
            logger.log_value('test_losses/opl', test_opl, epoch)
        else:
            logger.log_value('test_losses/cel', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': bare_model.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # save the last model
    state = {
        'opt': opt,
        'model': bare_model.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)
def main():
    """Supervised pre-training with a resumable ``latest.pth`` link.

    Parses options and builds the loaders for ``opt.dataset``. Normally it
    creates ``opt.model`` from scratch; with ``opt.load_latest`` it instead
    loads weights from ``<save_folder>/latest.pth``. It then trains for
    ``opt.epochs`` epochs and logs metrics to TensorBoard.

    Every ``opt.save_freq`` epochs a checkpoint is written and ``latest.pth``
    is re-pointed at it. The final weights, together with ``opt``, are saved
    as ``<save_folder>/<model>_last.pth``.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not supported.
    """
    opt = parse_option()

    def train_loader_for(dataset):
        return DataLoader(dataset, batch_size=opt.batch_size, shuffle=True,
                          drop_last=True, num_workers=opt.num_workers)

    def val_loader_for(dataset):
        return DataLoader(dataset, batch_size=opt.batch_size // 2, shuffle=False,
                          drop_last=False, num_workers=opt.num_workers // 2)

    # dataloader. Episodic (meta) loaders are not built because nothing in
    # this routine consumes them.
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = train_loader_for(
            ImageNet(args=opt, partition=train_partition, transform=train_trans))
        val_loader = val_loader_for(
            ImageNet(args=opt, partition='val', transform=test_trans))
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = train_loader_for(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans))
        val_loader = val_loader_for(
            TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans))
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        train_trans, test_trans = transforms_options['D']
        train_loader = train_loader_for(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans))
        val_loader = val_loader_for(
            CIFAR100(args=opt, partition='train', transform=test_trans))
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64 if opt.dataset == 'CIFAR-FS' else 60
    else:
        raise NotImplementedError(opt.dataset)

    latest_file = os.path.join(opt.save_folder, 'latest.pth')

    def point_latest_at(checkpoint_file):
        """Re-point ``latest_file`` at ``checkpoint_file``, replacing any old link."""
        tmp_link = latest_file + '.tmp'
        if os.path.lexists(tmp_link):
            os.remove(tmp_link)
        # Use a relative target. The link lives in the same folder as the
        # checkpoint, so it resolves correctly even when save_folder is relative.
        os.symlink(os.path.basename(checkpoint_file), tmp_link)
        # os.replace overwrites an existing link. A bare os.symlink onto
        # latest_file would raise FileExistsError from the second save onwards.
        os.replace(tmp_link, latest_file)

    # model
    if not opt.load_latest:
        model = create_model(opt.model, n_cls, opt.dataset)
    else:
        # NOTE(review): only the weights are restored here. The optimizer is
        # rebuilt and the epoch loop below restarts at 1.
        model = load_teacher(latest_file, n_cls, opt.dataset)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(), lr=opt.learning_rate,
                              momentum=opt.momentum, weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # Checkpoints store the unwrapped network. Without CUDA the model is never
    # wrapped, so test the wrapper type instead of opt.n_gpu.
    bare_model = model.module if isinstance(model, nn.DataParallel) else model

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    scheduler = None
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):
        if not opt.cosine:
            adjust_learning_rate(epoch, opt, optimizer)

        print("==> training...")
        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, model, criterion, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # Step after the epoch's optimizer updates (PyTorch >= 1.1 ordering),
        # so the first epoch runs at the initial learning rate.
        if scheduler is not None:
            scheduler.step()

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model, criterion, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': bare_model.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
            point_latest_at(save_file)

    # save the last model
    state = {
        'opt': opt,
        'model': bare_model.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)
def main():
    """Knowledge distillation: train a student network from a teacher.

    The teacher is loaded from ``opt.path_t`` and is never passed to the
    optimizer. The student ``opt.model_s`` is optimised with three losses:

    * a cross-entropy classification loss;
    * the KL distillation loss ``DistillKL(opt.kd_T)``;
    * an extra loss chosen by ``opt.distill``: ``kd``, ``contrast``,
      ``attention`` or ``hint``.

    Metrics are logged to TensorBoard every epoch. The student is
    checkpointed every ``opt.save_freq`` epochs and finally saved as
    ``<save_folder>/<model_s>_last.pth``.

    Raises:
        NotImplementedError: for an unsupported ``opt.dataset`` or ``opt.distill``.
    """
    opt = parse_option()

    # tensorboard logger
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    train_partition = 'trainval' if opt.use_trainval else 'train'

    def make_train_set(dataset_cls, transform):
        # 'contrast' distillation builds the training set in sampling mode
        # (is_sample=True, k=opt.nce_k).
        if opt.distill in ['contrast']:
            return dataset_cls(args=opt, partition=train_partition, transform=transform,
                               is_sample=True, k=opt.nce_k)
        return dataset_cls(args=opt, partition=train_partition, transform=transform)

    # dataloader. Episodic (meta) loaders are not built because nothing in
    # this routine consumes them.
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_set = make_train_set(ImageNet, train_trans)
        val_set = ImageNet(args=opt, partition='val', transform=test_trans)
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_set = make_train_set(TieredImageNet, train_trans)
        val_set = TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans)
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        train_trans, test_trans = transforms_options['D']
        train_set = make_train_set(CIFAR100, train_trans)
        val_set = CIFAR100(args=opt, partition='train', transform=test_trans)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64 if opt.dataset == 'CIFAR-FS' else 60
    else:
        raise NotImplementedError(opt.dataset)

    n_data = len(train_set)  # forwarded to NCELoss for 'contrast'
    train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True,
                              drop_last=True, num_workers=opt.num_workers)
    val_loader = DataLoader(val_set, batch_size=opt.batch_size // 2, shuffle=False,
                            drop_last=False, num_workers=opt.num_workers // 2)

    # model
    model_t = load_teacher(opt.path_t, n_cls, opt.dataset)
    model_s = create_model(opt.model_s, n_cls, opt.dataset)

    # Probe both networks once to read their feature widths, which size the
    # 'contrast' embeddings. No autograd graph is needed for this pass.
    # NOTE(review): the probe input is 84x84 for every dataset, including the
    # CIFAR-derived ones. Confirm the models accept that size.
    data = torch.randn(2, 3, 84, 84)
    model_t.eval()
    model_s.eval()
    with torch.no_grad():
        feat_t, _ = model_t(data, is_feat=True)
        feat_s, _ = model_s(data, is_feat=True)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'contrast':
        criterion_kd = NCELoss(opt, n_data)
        embed_s = Embed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = Embed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)  # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(), lr=opt.learning_rate,
                          momentum=opt.momentum, weight_decay=opt.weight_decay)

    # Append the teacher only after the optimizer is built, so its parameters
    # get no weight decay (they are not optimised at all).
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = True

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: ', teacher_acc)

    # set cosine annealing scheduler
    scheduler = None
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised model distillation
    for epoch in range(1, opt.epochs + 1):
        if not opt.cosine:
            adjust_learning_rate(epoch, opt, optimizer)

        print("==> training...")
        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, module_list,
                                      criterion_list, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # Step after the epoch's optimizer updates (PyTorch >= 1.1 ordering),
        # so the first epoch runs at the initial learning rate.
        if scheduler is not None:
            scheduler.step()

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s, criterion_cls, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # save the last model
    state = {
        'opt': opt,
        'model': model_s.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model_s))
    torch.save(state, save_file)
def get_train_loaders(opt, train_partition, worker_init_fn=None):
    """Create the supervised training and validation dataloaders.

    Args:
        opt: parsed options. Reads ``transform``, ``double_transform``,
            ``dataset``, ``use_trainval``, ``batch_size`` and ``num_workers``,
            plus ``aug_type`` when ``double_transform`` is set.
        train_partition: partition name of the training split, e.g.
            ``'train'`` or ``'trainval'``. It is not used for
            ``opt.dataset == 'cross'``.
        worker_init_fn: optional callable forwarded to both DataLoaders.

    Returns:
        A tuple ``(train_loader, val_loader, n_cls)``.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not supported.
    """
    train_trans, test_trans = transforms_options[opt.transform]
    if opt.double_transform:
        # Pair the standard augmentation with a contrastive one.
        contrast_trans = get_contrastive_aug(dataset=opt.dataset, aug_type=opt.aug_type)
        train_trans = TwoCropTransform(train_trans, contrast_trans)

    name = opt.dataset
    if name == 'miniImageNet':
        # ImageNet derivative - miniImageNet
        assert opt.transform == "A"
        train_set = ImageNet(args=opt, partition=train_partition, transform=train_trans)
        val_set = ImageNet(args=opt, partition='val', transform=test_trans)
        n_cls = 80 if opt.use_trainval else 64
    elif name == 'tieredImageNet':
        # ImageNet derivative - tieredImageNet
        assert opt.transform == "A"
        train_set = TieredImageNet(args=opt, partition=train_partition, transform=train_trans)
        val_set = TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans)
        n_cls = 448 if opt.use_trainval else 351
    elif name in ('CIFAR-FS', 'FC100'):
        # CIFAR-100 derivatives - both CIFAR-FS and FC100
        assert opt.transform in ("D", "Dcontrast")
        train_set = CIFAR100(args=opt, partition=train_partition, transform=train_trans)
        val_set = CIFAR100(args=opt, partition='train', transform=test_trans)
        if opt.use_trainval:
            n_cls = 80
        elif name == 'CIFAR-FS':
            n_cls = 64
        else:
            n_cls = 60
    elif name == 'cross':
        # Cross-domain experiments train on every split (train, val and test),
        # all with the training transform.
        assert opt.transform == "A"
        splits = [ImageNet(args=opt, partition=split, transform=train_trans)
                  for split in ('train', 'val', 'test')]
        train_set = ConcatDataset(splits)
        val_set = ImageNet(args=opt, partition='val', transform=test_trans)
        n_cls = 64 + 16 + 20  # train + val + test
    else:
        raise NotImplementedError(opt.dataset)

    train_loader = DataLoader(
        train_set,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=opt.num_workers,
        worker_init_fn=worker_init_fn,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=opt.batch_size // 2,
        shuffle=False,
        drop_last=False,
        num_workers=opt.num_workers // 2,
        worker_init_fn=worker_init_fn,
    )
    return train_loader, val_loader, n_cls
from pprint import PrettyPrinter as _PrettyPrinter

# A single shared printer, so every call uses the same formatting settings.
_utils_pp = _PrettyPrinter()


def pprint(x):
    """Pretty-print ``x`` using the module's shared ``PrettyPrinter``."""
    _utils_pp.pprint(x)


if __name__ == '__main__':
    opt = parse_option()
    pprint(vars(opt))

    # dataloader
    train_partition = 'train' if not opt.use_trainval else 'trainval'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
        )
        val_loader = DataLoader(
            ImageNet(args=opt, partition='val', transform=test_trans),
            batch_size=opt.batch_size // 2,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
def main():
    """Supervised pre-training with optional language shaping (LSL), logged to wandb.

    Steps:
        1. Parse options and start a wandb run (named ``opt.name`` when given).
        2. Build the loaders for ``opt.dataset``.
        3. Train ``opt.model`` for ``opt.epochs`` epochs, logging metrics to
           wandb and tracking the best validation accuracy in the run summary.
        4. Save the final weights in the wandb run directory.
        5. Meta-test the model on the test split twice: once using the logit
           layer and once using the layer before the logits.

    With ``opt.lsl`` set, a ``TextProposal`` language model over the dataset
    vocabulary is built and passed to ``train``/``validate``.

    Raises:
        NotImplementedError: in three cases:
            * ``opt.dataset`` is not supported;
            * ``opt.use_trainval`` is set on CUB_200_2011;
            * ``opt.lsl`` is set on a dataset that does not load a vocabulary
              (anything other than CUB_200_2011).
    """
    opt = parse_option()
    # Only the CUB branch below loads a vocabulary. Fail before any work, and
    # before creating a wandb run, instead of hitting a NameError later.
    if opt.lsl and opt.dataset != 'CUB_200_2011':
        raise NotImplementedError(
            'lsl requires language annotations; dataset not supported: {}'.format(opt.dataset))

    if opt.name is not None:
        wandb.init(name=opt.name)
    else:
        wandb.init()
    wandb.config.update(opt)

    def train_loader_for(dataset):
        return DataLoader(dataset, batch_size=opt.batch_size, shuffle=True,
                          drop_last=True, num_workers=opt.num_workers)

    def val_loader_for(dataset):
        return DataLoader(dataset, batch_size=opt.batch_size // 2, shuffle=False,
                          drop_last=False, num_workers=opt.num_workers // 2)

    def meta_loader_for(dataset):
        return DataLoader(dataset, batch_size=opt.test_batch_size, shuffle=False,
                          drop_last=False, num_workers=opt.num_workers)

    # vocab/devocab stay None unless the CUB branch loads them for LSL.
    vocab = None
    devocab = None

    # dataloader. Only the episodic *test* loader is built; it is used for
    # the final meta-test below.
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = train_loader_for(
            ImageNet(args=opt, partition=train_partition, transform=train_trans))
        val_loader = val_loader_for(
            ImageNet(args=opt, partition='val', transform=test_trans))
        meta_testloader = meta_loader_for(MetaImageNet(
            args=opt, partition='test',
            train_transform=train_trans, test_transform=test_trans))
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = train_loader_for(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans))
        val_loader = val_loader_for(
            TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans))
        meta_testloader = meta_loader_for(MetaTieredImageNet(
            args=opt, partition='test',
            train_transform=train_trans, test_transform=test_trans))
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        train_trans, test_trans = transforms_options['D']
        train_loader = train_loader_for(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans))
        val_loader = val_loader_for(
            CIFAR100(args=opt, partition='train', transform=test_trans))
        meta_testloader = meta_loader_for(MetaCIFAR100(
            args=opt, partition='test',
            train_transform=train_trans, test_transform=test_trans))
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64 if opt.dataset == 'CIFAR-FS' else 60
    elif opt.dataset == 'CUB_200_2011':
        if opt.use_trainval:
            raise NotImplementedError(opt.dataset)  # no trainval supported yet
        train_trans, test_trans = transforms_options['C']
        if opt.lsl:
            vocab = lang_utils.load_vocab(opt.lang_dir)
            devocab = {v: k for k, v in vocab.items()}
        train_loader = train_loader_for(
            CUB2011(args=opt, partition=train_partition, transform=train_trans, vocab=vocab))
        val_loader = val_loader_for(
            CUB2011(args=opt, partition='val', transform=test_trans, vocab=vocab))
        meta_testloader = meta_loader_for(MetaCUB2011(
            args=opt, partition='test',
            train_transform=train_trans, test_transform=test_trans, vocab=vocab))
        n_cls = 100
    else:
        raise NotImplementedError(opt.dataset)

    print('Amount training data: {}'.format(len(train_loader.dataset)))
    print('Amount val data: {}'.format(len(val_loader.dataset)))

    # model
    model = create_model(opt.model, n_cls, opt.dataset)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(), lr=opt.learning_rate,
                              momentum=opt.momentum, weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()

    # lsl
    lang_model = None
    if opt.lsl:
        vecs = lang_utils.glove_init(vocab, emb_size=opt.lang_emb_size) if opt.glove_init else None
        embedding_model = nn.Embedding(len(vocab), opt.lang_emb_size, _weight=vecs)
        if opt.freeze_emb:
            embedding_model.weight.requires_grad = False
        lang_input_size = n_cls if opt.use_logit else 640  # 640 for resnet12
        lang_model = TextProposal(
            embedding_model,
            input_size=lang_input_size,
            hidden_size=opt.lang_hidden_size,
            project_input=lang_input_size != opt.lang_hidden_size,
            rnn=opt.rnn_type,
            num_layers=opt.rnn_num_layers,
            dropout=opt.rnn_dropout,
            vocab=vocab,
            **lang_utils.get_special_indices(vocab)
        )
        # NOTE(review): `optimizer` above holds only model.parameters(). Unless
        # `train` optimises lang_model itself, the language model is never
        # updated. Confirm.

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True
        if opt.lsl:
            embedding_model = embedding_model.cuda()
            lang_model = lang_model.cuda()

    # Checkpoints store the unwrapped network. Without CUDA the model is never
    # wrapped, so test the wrapper type instead of opt.n_gpu.
    bare_model = model.module if isinstance(model, nn.DataParallel) else model

    # set cosine annealing scheduler
    scheduler = None
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.epochs, eta_min, -1)

    # routine: supervised pre-training
    best_val_acc = 0
    for epoch in range(1, opt.epochs + 1):
        if not opt.cosine:
            adjust_learning_rate(epoch, opt, optimizer)

        print("==> training...")
        time1 = time.time()
        train_acc, train_loss, train_lang_loss = train(
            epoch, train_loader, model, criterion, optimizer, opt, lang_model,
            devocab=devocab
        )
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # Step after the epoch's optimizer updates (PyTorch >= 1.1 ordering),
        # so the first epoch runs at the initial learning rate.
        if scheduler is not None:
            scheduler.step()

        print("==> validating...")
        test_acc, test_acc_top5, test_loss, test_lang_loss = validate(
            val_loader, model, criterion, opt, lang_model
        )

        # wandb
        log_metrics = {
            'train_acc': train_acc,
            'train_loss': train_loss,
            'val_acc': test_acc,
            'val_acc_top5': test_acc_top5,
            'val_loss': test_loss,
        }
        if opt.lsl:
            log_metrics['train_lang_loss'] = train_lang_loss
            log_metrics['val_lang_loss'] = test_lang_loss
        wandb.log(log_metrics, step=epoch)

        # Track the running best. Without updating best_val_acc, every epoch
        # with positive accuracy would overwrite the summary.
        if test_acc > best_val_acc:
            best_val_acc = test_acc
            wandb.run.summary['best_val_acc'] = test_acc
            wandb.run.summary['best_val_acc_epoch'] = epoch

    # save the last model
    state = {
        'opt': opt,
        'model': bare_model.state_dict(),
    }
    save_file = os.path.join(wandb.run.dir, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)

    # evaluate on test set
    print("==> testing...")
    start = time.time()
    (test_acc, test_std), (test_acc5, test_std5) = meta_test(model, meta_testloader)
    test_time = time.time() - start
    print('Using logit layer for embedding')
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std, test_time))
    print('test_acc top 5: {:.4f}, test_std top 5: {:.4f}, time: {:.1f}'.format(
        test_acc5, test_std5, test_time))

    start = time.time()
    (test_acc_feat, test_std_feat), (test_acc5_feat, test_std5_feat) = meta_test(
        model, meta_testloader, use_logit=False)
    test_time = time.time() - start
    print('Using layer before logits for embedding')
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
    print('test_acc_feat top 5: {:.4f}, test_std top 5: {:.4f}, time: {:.1f}'.format(
        test_acc5_feat, test_std5_feat, test_time))

    wandb.run.summary['test_acc'] = test_acc
    wandb.run.summary['test_std'] = test_std
    wandb.run.summary['test_acc5'] = test_acc5
    wandb.run.summary['test_std5'] = test_std5
    wandb.run.summary['test_acc_feat'] = test_acc_feat
    wandb.run.summary['test_std_feat'] = test_std_feat
    wandb.run.summary['test_acc5_feat'] = test_acc5_feat
    wandb.run.summary['test_std5_feat'] = test_std5_feat
def main():
    """Meta-train a model with an inner-loop-adapted classifier on miniImageNet.

    Parses command-line options, builds two episodic training streams, an
    episodic validation loader and a supervised validation loader, then runs
    ``opt.num_steps`` outer steps. Each outer step accumulates gradients over
    ``opt.apply_every`` sub-batches produced by ``train_step`` before a single
    Adam update of the model (and, when ``opt.learn_lr`` is set, of the
    inner-loop learning rate).

    Validation (episodic via ``test_run`` and ``meta_test``, supervised via
    ``validate``) runs at step 1 and every ``opt.eval_freq`` steps. The
    checkpoint with the best ``meta_test`` feature accuracy is written to
    ``<save_folder>/<model>_best.pth``; the final state goes to
    ``<save_folder>/<model>_last.pth``.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not ``"miniImageNet"``.
        ValueError: if ``opt.augment`` is not one of none/all/spt/qry, or if
            ``opt.batch_size`` is not divisible by ``opt.apply_every``.
    """
    opt = parse_option()
    print(pp.pformat(vars(opt)))

    if opt.dataset != "miniImageNet":
        # Only miniImageNet loaders are built below; other datasets previously
        # failed much later with a NameError on the undefined loaders.
        raise NotImplementedError(opt.dataset)

    train_trans, test_trans = transforms_options[opt.transform]
    # Pick which transform the support ("spt") and query ("qry") images of the
    # training episodes receive: augmentation on neither, both, or one side.
    if opt.augment == "none":
        train_train_trans = train_test_trans = test_trans
    elif opt.augment == "all":
        train_train_trans = train_test_trans = train_trans
    elif opt.augment == "spt":
        train_train_trans = train_trans
        train_test_trans = test_trans
    elif opt.augment == "qry":
        train_train_trans = test_trans
        train_test_trans = train_trans
    else:
        raise ValueError(f"unknown augment option: {opt.augment!r}")
    print("spt trans")
    print(train_train_trans)
    print("qry trans")
    print(train_test_trans)

    # The outer batch is split into `apply_every` gradient-accumulation chunks.
    sub_batch_size, rmd = divmod(opt.batch_size, opt.apply_every)
    if rmd != 0:
        raise ValueError(
            f"batch_size ({opt.batch_size}) must be divisible by "
            f"apply_every ({opt.apply_every})"
        )
    print("Train sub batch-size:", sub_batch_size)

    meta_train_dataset = MetaImageNet(
        args=opt,
        partition="train",
        train_transform=train_train_trans,
        test_transform=train_test_trans,
        fname="miniImageNet_category_split_train_phase_%s.pickle",
        fix_seed=False,
        n_test_runs=10000000,  # big number to never stop
        new_labels=False,
    )
    meta_trainloader = DataLoader(
        meta_train_dataset,
        batch_size=sub_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=opt.num_workers,
        pin_memory=True,
    )
    # Second training stream with n_qry_way x n_qry_shot support images and no
    # queries; its support images are appended to each task's query set below.
    meta_train_dataset_qry = MetaImageNet(
        args=opt,
        partition="train",
        train_transform=train_train_trans,
        test_transform=train_test_trans,
        fname="miniImageNet_category_split_train_phase_%s.pickle",
        fix_seed=False,
        n_test_runs=10000000,  # big number to never stop
        new_labels=False,
        n_ways=opt.n_qry_way,
        n_shots=opt.n_qry_shot,
        n_queries=0,
    )
    meta_trainloader_qry = DataLoader(
        meta_train_dataset_qry,
        batch_size=sub_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=opt.num_workers,
        pin_memory=True,
    )
    meta_val_dataset = MetaImageNet(
        args=opt,
        partition="val",
        train_transform=test_trans,
        test_transform=test_trans,
        fix_seed=False,
        n_test_runs=200,
        n_ways=5,
        n_shots=5,
        n_queries=15,
    )
    meta_valloader = DataLoader(
        meta_val_dataset,
        batch_size=opt.test_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=opt.num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        ImageNet(args=opt, partition="val", transform=test_trans),
        batch_size=opt.sup_val_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=opt.num_workers,
        pin_memory=True,
    )

    n_cls = len(meta_train_dataset.classes)
    print(n_cls)

    model = create_model(
        opt.model,
        n_cls,
        opt.dataset,
        opt.drop_rate,
        opt.dropblock,
        opt.track_stats,
        opt.initializer,
        opt.weight_norm,
        activation=opt.activation,
        normalization=opt.normalization,
    )
    print(model)
    criterion = nn.CrossEntropyLoss()

    # NOTE: the model is never wrapped in nn.DataParallel in this function, so
    # checkpoints always use model.state_dict() (model.module does not exist).
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name())
        device = torch.device("cuda")
        model = model.to(device)
        criterion = criterion.to(device)
        cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    print("Learning rate")
    print(opt.learning_rate)
    print("Inner Learning rate")
    print(opt.inner_lr)
    if opt.learn_lr:
        print("Optimizing learning rate")
    # The inner-loop step size is a Parameter so it can optionally be
    # meta-learned together with the model weights by the outer optimizer.
    inner_lr = nn.Parameter(torch.tensor(opt.inner_lr), requires_grad=opt.learn_lr)
    optimizer = torch.optim.Adam(
        list(model.parameters()) + [inner_lr] if opt.learn_lr else model.parameters(),
        lr=opt.learning_rate,
    )

    logger = SummaryWriter(logdir=opt.tb_folder, flush_secs=10, comment=opt.model_name)
    comet_logger = Experiment(
        api_key=os.environ["COMET_API_KEY"],
        project_name=opt.comet_project_name,
        workspace=opt.comet_workspace,
        disabled=not opt.logcomet,
        auto_metric_logging=False,
    )
    comet_logger.set_name(opt.model_name)
    comet_logger.log_parameters(vars(opt))
    comet_logger.set_model_graph(str(model))

    if opt.cosine:
        eta_min = opt.learning_rate * opt.cosine_factor
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.num_steps, eta_min, -1
        )

    # Both training streams are effectively infinite (huge n_test_runs), so
    # plain iterators can be advanced with next() for the whole run.
    data_sampler = iter(meta_trainloader)
    data_sampler_qry = iter(meta_trainloader_qry)

    metric_keys = ("foa", "fol", "ioa", "iil", "fil", "iia", "fia")
    pbar = tqdm(
        range(1, opt.num_steps + 1),
        miniters=opt.print_freq,
        mininterval=3,
        maxinterval=30,
        ncols=0,
    )
    best_val_acc = 0.0
    for step in pbar:
        if not opt.cosine:
            adjust_learning_rate(step, opt, optimizer)

        totals = dict.fromkeys(metric_keys, 0.0)
        for j in range(opt.apply_every):
            x_spt, y_spt, x_qry, y_qry = [t.to(device) for t in next(data_sampler)]
            x_qry2, y_qry2, _, _ = [t.to(device) for t in next(data_sampler_qry)]
            y_spt = y_spt.flatten(1)
            y_qry2 = y_qry2.flatten(1)
            # Outer-loop set = support + query + extra images from the 2nd stream.
            x_qry = torch.cat((x_spt, x_qry, x_qry2), 1)
            y_qry = torch.cat((y_spt, y_qry, y_qry2), 1)
            if step == 1 and j == 0:
                print(x_spt.size(), y_spt.size(), x_qry.size(), y_qry.size())

            info = train_step(
                model,
                model.classifier,
                None,  # no inner optimizer object; train_step gets inner_lr
                inner_lr,
                x_spt,
                y_spt,
                x_qry,
                y_qry,
                reset_head=opt.reset_head,
                num_steps=opt.num_inner_steps,
            )
            # presumably train_step returns sums over the sub-batch; dividing by
            # the full batch size turns the accumulated totals into batch means
            # — TODO confirm against train_step.
            scaled = {key: info[key] / opt.batch_size for key in metric_keys}
            scaled["fol"].backward()
            for key in metric_keys:
                totals[key] += scaled[key].detach()

        optimizer.step()
        optimizer.zero_grad()
        inner_lr.data.clamp_(min=0.001)
        if opt.cosine:
            scheduler.step()

        if (step == 1) or (step % opt.eval_freq == 0):
            val_info = test_run(
                iter(meta_valloader),
                model,
                model.classifier,
                torch.optim.SGD(model.classifier.parameters(), lr=inner_lr.item()),
                num_inner_steps=opt.num_inner_steps_test,
                device=device,
            )
            val_acc_feat, val_std_feat = meta_test(
                model,
                meta_valloader,
                use_logit=False,
            )
            val_acc = val_info["outer"]["acc"].cpu()
            val_loss = val_info["outer"]["loss"].cpu()

            sup_acc, sup_acc_top5, sup_loss = validate(
                val_loader,
                model,
                criterion,
                print_freq=100000000,
            )
            sup_acc = sup_acc.item()
            sup_acc_top5 = sup_acc_top5.item()

            print(f"\nValidation step {step}")
            print(f"MAML 5-way-5-shot accuracy: {val_acc.item()}")
            print(f"LR 5-way-5-shot accuracy: {val_acc_feat}+-{val_std_feat}")
            print(
                f"Supervised accuracy: Acc@1: {sup_acc} Acc@5: {sup_acc_top5} Loss: {sup_loss}"
            )
            if val_acc_feat > best_val_acc:
                best_val_acc = val_acc_feat
                # float() accepts Python floats as well as numpy/torch scalars,
                # whereas .item() fails on a plain float.
                print(
                    f"New best validation accuracy {float(best_val_acc)} saving checkpoints\n"
                )
                torch.save(
                    {
                        "opt": opt,
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "step": step,
                        "val_acc": val_acc,
                        "val_loss": val_loss,
                        "val_acc_lr": val_acc_feat,
                        "sup_acc": sup_acc,
                        "sup_acc_top5": sup_acc_top5,
                        "sup_loss": sup_loss,
                    },
                    os.path.join(opt.save_folder, "{}_best.pth".format(opt.model)),
                )

            comet_logger.log_metrics(
                dict(
                    fol=val_loss,
                    foa=val_acc,
                    acc_lr=val_acc_feat,
                    sup_acc=sup_acc,
                    sup_acc_top5=sup_acc_top5,
                    sup_loss=sup_loss,
                ),
                step=step,
                prefix="val",
            )
            logger.add_scalar("val_acc", val_acc, step)
            logger.add_scalar("val_loss", val_loss, step)
            logger.add_scalar("val_acc_lr", val_acc_feat, step)
            logger.add_scalar("sup_acc", sup_acc, step)
            logger.add_scalar("sup_acc_top5", sup_acc_top5, step)
            logger.add_scalar("sup_loss", sup_loss, step)

        if (step == 1) or (step % opt.eval_freq == 0) or (step % opt.print_freq == 0):
            tfol = totals["fol"].cpu()
            tfoa = totals["foa"].cpu()
            tioa = totals["ioa"].cpu()
            tiil = totals["iil"].cpu()
            tfil = totals["fil"].cpu()
            tiia = totals["iia"].cpu()
            tfia = totals["fia"].cpu()
            comet_logger.log_metrics(
                dict(
                    fol=tfol,
                    foa=tfoa,
                    ioa=tioa,  # was logged as tfoa (final outer acc) by mistake
                    iil=tiil,
                    fil=tfil,
                    iia=tiia,
                    fia=tfia,
                ),
                step=step,
                prefix="train",
            )
            logger.add_scalar("train_acc", tfoa.item(), step)
            logger.add_scalar("train_loss", tfol.item(), step)
            logger.add_scalar("train_ioa", tioa, step)
            logger.add_scalar("train_iil", tiil, step)
            logger.add_scalar("train_fil", tfil, step)
            logger.add_scalar("train_iia", tiia, step)
            logger.add_scalar("train_fia", tfia, step)
            # val_* and sup_acc are always bound here: step 1 always validates.
            pbar.set_postfix(
                fol=f"{tfol.item():.2f}",
                foa=f"{tfoa.item():.2f}",
                ioa=f"{tioa.item():.2f}",
                iia=f"{tiia.item():.2f}",
                fia=f"{tfia.item():.2f}",
                vl=f"{val_loss.item():.2f}",
                va=f"{val_acc.item():.2f}",
                valr=f"{val_acc_feat:.2f}",
                lr=f"{inner_lr.item():.4f}",
                vsa=f"{sup_acc:.2f}",
                refresh=True,
            )

    # save the last model
    state = {
        "opt": opt,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
    }
    save_file = os.path.join(opt.save_folder, "{}_last.pth".format(opt.model))
    torch.save(state, save_file)
def get_dataloaders(opt):
    """Build the supervised and episodic data loaders for ``opt.dataset``.

    Supervised loaders use ``transforms_options``; the episodic (meta) test and
    validation loaders use ``transforms_test_options``. For CIFAR-FS/FC100 the
    supervised loaders always use the 'D' transform pair, and the "validation"
    loader reads the 'train' partition with test-time transforms.

    Args:
        opt: parsed options. Reads ``dataset``, ``use_trainval``, ``transform``,
            ``batch_size``, ``test_batch_size`` and ``num_workers`` (and passes
            ``opt`` on to the dataset classes).

    Returns:
        Tuple ``(train_loader, val_loader, meta_testloader, meta_valloader,
        n_cls, no_sample)``. ``n_cls`` is the number of training classes and
        ``no_sample`` the number of samples in the supervised training set.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not supported.
    """
    train_partition = 'trainval' if opt.use_trainval else 'train'

    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            ImageNet(args=opt, partition='val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        # Episodic loaders use the dedicated test-time transform pair.
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        train_trans, test_trans = transforms_options['D']
        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            CIFAR100(args=opt, partition='train', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        elif opt.dataset == 'CIFAR-FS':
            n_cls = 64
        else:  # FC100
            n_cls = 60
    else:
        raise NotImplementedError(opt.dataset)

    # Measure the training set through the loader's own dataset instead of
    # constructing (and re-loading from disk) a second copy just to call len().
    no_sample = len(train_loader.dataset)
    return train_loader, val_loader, meta_testloader, meta_valloader, n_cls, no_sample
def main():
    """Supervised pre-training entry point.

    Parses options, builds the supervised train/validation loaders for
    ``opt.dataset``, trains the model for ``opt.epochs`` epochs with SGD (or
    Adam when ``opt.adam``), logs to TensorBoard and Comet, saves a checkpoint
    every ``opt.save_freq`` epochs and finally writes ``<model>_last.pth`` to
    ``opt.save_folder``.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not supported.
    """
    opt = parse_option()

    # dataloader (only the supervised loaders are used by this routine)
    train_partition = "trainval" if opt.use_trainval else "train"
    if opt.dataset == "miniImageNet":
        train_trans, test_trans = transforms_options[opt.transform]
        train_set = ImageNet(args=opt, partition=train_partition, transform=train_trans)
        val_set = ImageNet(args=opt, partition="val", transform=test_trans)
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == "tieredImageNet":
        train_trans, test_trans = transforms_options[opt.transform]
        train_set = TieredImageNet(
            args=opt, partition=train_partition, transform=train_trans
        )
        val_set = TieredImageNet(
            args=opt, partition="train_phase_val", transform=test_trans
        )
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ("CIFAR-FS", "FC100"):
        train_trans, test_trans = transforms_options["D"]
        train_set = CIFAR100(args=opt, partition=train_partition, transform=train_trans)
        # NOTE: validation reads the "train" partition with test-time transforms.
        val_set = CIFAR100(args=opt, partition="train", transform=test_trans)
        if opt.use_trainval:
            n_cls = 80
        elif opt.dataset == "CIFAR-FS":
            n_cls = 64
        else:  # FC100
            n_cls = 60
    else:
        raise NotImplementedError(opt.dataset)

    train_loader = DataLoader(
        train_set,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=opt.num_workers,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=opt.batch_size // 2,
        shuffle=False,
        drop_last=False,
        num_workers=opt.num_workers // 2,
    )

    # model
    model = create_model(opt.model, n_cls, opt.dataset, opt.drop_rate, opt.dropblock)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(
            model.parameters(), lr=opt.learning_rate, weight_decay=0.0005
        )
    else:
        optimizer = optim.SGD(
            model.parameters(),
            lr=opt.learning_rate,
            momentum=opt.momentum,
            weight_decay=opt.weight_decay,
        )

    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True
    # Unwrap based on what actually happened: DataParallel is only applied when
    # CUDA is available, so testing opt.n_gpu alone crashed on CPU runs.
    bare_model = model.module if isinstance(model, nn.DataParallel) else model

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
    comet_logger = Experiment(
        api_key=os.environ["COMET_API_KEY"],
        project_name=opt.comet_project_name,
        workspace=opt.comet_workspace,
        disabled=not opt.logcomet,
    )
    comet_logger.set_name(opt.model_name)
    comet_logger.log_parameters(vars(opt))

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** opt.cosine_factor)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1
        )

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):
        if not opt.cosine:
            adjust_learning_rate(epoch, opt, optimizer)

        print("==> training...")
        time1 = time.time()
        with comet_logger.train():
            train_acc, train_loss = train(
                epoch, train_loader, model, criterion, optimizer, opt
            )
            comet_logger.log_metrics(
                {"acc": train_acc.cpu(), "loss_epoch": train_loss}, epoch=epoch
            )
        time2 = time.time()
        print("epoch {}, total time {:.2f}".format(epoch, time2 - time1))
        logger.log_value("train_acc", train_acc, epoch)
        logger.log_value("train_loss", train_loss, epoch)

        # Step the cosine schedule after this epoch's optimizer updates
        # (PyTorch >= 1.1 ordering); stepping it before training skipped the
        # schedule's initial learning rate.
        if opt.cosine:
            scheduler.step()

        with comet_logger.validate():
            test_acc, test_acc_top5, test_loss = validate(
                val_loader, model, criterion, opt
            )
            comet_logger.log_metrics(
                {
                    "acc": test_acc.cpu(),
                    "acc_top5": test_acc_top5.cpu(),
                    "loss": test_loss,
                },
                epoch=epoch,
            )
        logger.log_value("test_acc", test_acc, epoch)
        logger.log_value("test_acc_top5", test_acc_top5, epoch)
        logger.log_value("test_loss", test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print("==> Saving...")
            state = {
                "epoch": epoch,
                "model": bare_model.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, "ckpt_epoch_{epoch}.pth".format(epoch=epoch)
            )
            torch.save(state, save_file)

    # save the last model
    state = {
        "opt": opt,
        "model": bare_model.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, "{}_last.pth".format(opt.model))
    torch.save(state, save_file)
def get_dataloaders(opt):
    """Build the supervised and episodic data loaders for ``opt.dataset``.

    Supervised loaders use ``transforms_options``; the episodic (meta) test and
    validation loaders use ``transforms_test_options``. The ``'toy'`` dataset
    only builds supervised loaders and returns early.

    Args:
        opt: parsed options. Reads ``dataset``, ``use_trainval``, ``transform``,
            ``batch_size``, ``test_batch_size`` and ``num_workers`` (and passes
            ``opt`` on to the dataset classes).

    Returns:
        Tuple ``(train_loader, val_loader, meta_testloader, meta_valloader,
        n_cls)``. For ``'toy'`` the two meta-loader slots hold the integer
        placeholder ``5`` — NOTE(review): callers presumably never use them for
        the toy dataset; confirm.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not supported.
    """
    train_partition = 'trainval' if opt.use_trainval else 'train'

    if opt.dataset == 'toy':
        train_trans, test_trans = transforms_options['D']
        train_loader = DataLoader(
            CIFAR100_toy(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            CIFAR100_toy(args=opt, partition='train', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        n_cls = 5
        return train_loader, val_loader, 5, 5, n_cls

    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            ImageNet(args=opt, partition='val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        # Episodic loaders use the dedicated test-time transform pair.
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaImageNet(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaImageNet(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='test',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaTieredImageNet(args=opt, partition='val',
                               train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        train_trans, test_trans = transforms_options['D']
        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size, shuffle=True, drop_last=True,
            num_workers=opt.num_workers)
        # NOTE: validation reads the 'train' partition with test-time transforms.
        val_loader = DataLoader(
            CIFAR100(args=opt, partition='train', transform=test_trans),
            batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
            num_workers=opt.num_workers // 2)
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(
            MetaCIFAR100(args=opt, partition='test',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        meta_valloader = DataLoader(
            MetaCIFAR100(args=opt, partition='val',
                         train_transform=train_trans, test_transform=test_trans),
            batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
            num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        elif opt.dataset == 'CIFAR-FS':
            n_cls = 64
        else:  # FC100
            n_cls = 60
    else:
        raise NotImplementedError(opt.dataset)

    return train_loader, val_loader, meta_testloader, meta_valloader, n_cls