def main():
    """Entry point: train an unsupervised instance-discrimination model on ImageNet.

    Builds a torchvision model, wraps it for (distributed) data-parallel
    execution, trains with either NCE (lemniscate memory bank) or a linear
    softmax criterion, and evaluates with nearest-neighbour / kNN retrieval.
    Reads configuration from the module-level ``parser``; mutates the module
    globals ``args`` and ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    # create model (architecture chosen by name from the models namespace)
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](low_dim=args.low_dim)

    if not args.distributed:
        # AlexNet/VGG: only the conv features are data-parallelised (classic
        # PyTorch example idiom — their large FC layers parallelise poorly).
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # ImageNet channel statistics
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # ImageFolderInstance also yields the sample index, which the memory bank
    # needs to address each instance.
    train_dataset = datasets.ImageFolderInstance(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2,1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    # shuffle only when no sampler is used (DataLoader forbids both at once)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolderInstance(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define lemniscate (non-parametric memory bank) and loss function (criterion)
    ndata = train_dataset.__len__()
    if args.nce_k > 0:
        # NCE with nce_k negative samples per instance
        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t, args.nce_m).cuda()
        criterion = NCECriterion(ndata).cuda()
    else:
        # full softmax over all instances
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t, args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint (restores model, memory bank,
    # optimizer state, epoch counter and the best score so far)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        # evaluation-only mode: run kNN classification once and exit
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # re-shuffle the distributed shards each epoch
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch)

        # evaluate on validation set (nearest-neighbour accuracy)
        prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)

        # remember best prec@1 and save checkpoint
        # NOTE(review): best_prec1 must be initialised at module level before
        # the first comparison when not resuming — confirm against file top.
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'lemniscate': lemniscate,
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best)

    # evaluate KNN after last epoch
    kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
def main():
    """Entry point: seeded, single-fold unsupervised training on fundus images.

    Fixes all RNG seeds for reproducibility, builds the model and fundus
    dataset loaders, trains with one of several batch criteria (multi-task /
    synthesis / multi-augmentation / plain CE), and periodically evaluates by
    kNN, logging to TensorBoard and a per-fold result file.

    Bug fixed: ``worker_init_fn=random.seed(my_whole_seed)`` *called*
    ``random.seed`` at DataLoader construction and passed its return value
    (``None``) as the init fn, so workers were never seeded. A real callable
    is now passed, seeding each worker deterministically.
    """
    global args, best_prec1
    args = parser.parse_args()

    # init seed: fix every RNG source for reproducibility
    my_whole_seed = 222
    random.seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.manual_seed(my_whole_seed)
    torch.cuda.manual_seed_all(my_whole_seed)
    torch.cuda.manual_seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(my_whole_seed)

    def _seed_worker(worker_id):
        # Deterministically seed each DataLoader worker process.
        random.seed(my_whole_seed + worker_id)

    # runs exactly one fold: [seedstart, seedstart + 1)
    for kk_time in range(args.seedstart, args.seedstart + 1):
        args.seed = kk_time
        args.result = args.result + str(args.seed)  # per-fold result dir

        # create model
        model = models.__dict__[args.arch](low_dim=args.low_dim,
                                           multitask=args.multitask,
                                           showfeature=args.showfeature,
                                           domain=args.domain,
                                           args=args)
        model = torch.nn.DataParallel(model).cuda()
        print('Number of learnable params', get_learnable_para(model) / 1000000., " M")

        # Data loading code (ImageNet statistics)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        aug = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
        aug_test = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor(),
             normalize])

        # load dataset (cross-validation split with independent test set)
        import datasets.fundus_amd_syn_crossvalidation_ind as medicaldata
        train_dataset = medicaldata.traindataset(root=args.data, transform=aug,
                                                 train=True, args=args)
        # shuffle=False: sample order is fixed so memory-bank indices stay stable
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=False,
            pin_memory=True, num_workers=4,
            drop_last=True if args.multiaug else False,
            worker_init_fn=_seed_worker)

        valid_dataset = medicaldata.traindataset(root=args.data, transform=aug_test,
                                                 train=False, args=args)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=args.batch_size, shuffle=False,
            pin_memory=True, num_workers=4,
            worker_init_fn=_seed_worker)

        # define lemniscate (instance memory bank) and loss function (criterion)
        ndata = len(train_dataset)
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t, args.nce_m).cuda()

        if args.multitaskposrot:
            cls_criterion = nn.CrossEntropyLoss().cuda()
        else:
            cls_criterion = None

        # choose the main criterion by training mode
        if args.multitaskposrot:
            print("running multi task with miccai")
            criterion = BatchCriterion(1, 0.1, args.batch_size, args).cuda()
        elif args.synthesis:
            print("running synthesis")
            criterion = BatchCriterionFour(1, 0.1, args.batch_size, args).cuda()
        elif args.multiaug:
            print("running cvpr")
            criterion = BatchCriterion(1, 0.1, args.batch_size, args).cuda()
        else:
            criterion = nn.CrossEntropyLoss().cuda()

        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                     weight_decay=args.weight_decay)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                # NOTE(review): torch.load unpickles arbitrary objects — only
                # load checkpoints from trusted sources.
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                lemniscate = checkpoint['lemniscate']
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        if args.evaluate:
            # evaluation-only mode: one kNN pass, append metrics, and exit
            knn_num = 100
            auc, acc, precision, recall, f1score = kNN(args, model, lemniscate,
                                                       train_loader, val_loader,
                                                       knn_num, args.nce_t, 2)
            with open("savemodels/result.txt", "a+") as f:
                f.write("auc: %.4f\n" % (auc))
                f.write("acc: %.4f\n" % (acc))
                f.write("pre: %.4f\n" % (precision))
                f.write("recall: %.4f\n" % (recall))
                f.write("f1score: %.4f\n" % (f1score))
            return

        # mkdir result folder and tensorboard
        os.makedirs(args.result, exist_ok=True)
        writer = SummaryWriter("runs/" + str(args.result.split("/")[-1]))
        writer.add_text('Text', str(args))

        # snapshot the source tree into the result folder for provenance
        import shutil, glob
        source = glob.glob("*.py")
        source += glob.glob("*/*.py")
        os.makedirs(args.result + "/code_file", exist_ok=True)
        for file in source:
            name = file.split("/")[0]
            if name == file:
                # top-level file
                shutil.copy(file, args.result + "/code_file/")
            else:
                # file inside a subdirectory — mirror the directory
                os.makedirs(args.result + "/code_file/" + name, exist_ok=True)
                shutil.copy(file, args.result + "/code_file/" + name)

        for epoch in range(args.start_epoch, args.epochs):
            lr = adjust_learning_rate(optimizer, epoch, args, [1000, 2000])
            writer.add_scalar("lr", lr, epoch)

            # train for one epoch
            loss = train(train_loader, model, lemniscate, criterion,
                         cls_criterion, optimizer, epoch, writer)
            writer.add_scalar("train_loss", loss, epoch)

            # periodic evaluation + checkpoint
            if epoch % 200 == 0 or (epoch in [1600, 1800, 2000]):
                auc, acc, precision, recall, f1score = kNN(args, model, lemniscate,
                                                           train_loader, val_loader,
                                                           100, args.nce_t, 2)
                writer.add_scalar("test_auc", auc, epoch)
                writer.add_scalar("test_acc", acc, epoch)
                writer.add_scalar("test_precision", precision, epoch)
                writer.add_scalar("test_recall", recall, epoch)
                writer.add_scalar("test_f1score", f1score, epoch)

                with open(args.result + "/result.txt", "a+") as f:
                    f.write("epoch " + str(epoch) + "\n")
                    f.write("auc: %.4f\n" % (auc))
                    f.write("acc: %.4f\n" % (acc))
                    f.write("pre: %.4f\n" % (precision))
                    f.write("recall: %.4f\n" % (recall))
                    f.write("f1score: %.4f\n" % (f1score))

                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'lemniscate': lemniscate,
                        'optimizer': optimizer.state_dict(),
                    },
                    filename=args.result + "/fold" + str(args.seedstart) +
                    "-epoch-" + str(epoch) + ".pth.tar")
def main():
    """Entry point: instance-discrimination training on frame sequences.

    Variant of the lemniscate training loop that additionally tracks
    nearest-neighbour scores for 'past' and 'future' predictions and keeps a
    separate best checkpoint for each. Device placement goes through
    ``get_device(args.gpu)`` rather than bare ``.cuda()``. Mutates the module
    globals ``args``, ``best_prec1``, ``best_prec1_past``,
    ``best_prec1_future``.
    """
    global args, best_prec1, best_prec1_past, best_prec1_future
    args = parser.parse_args()
    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    # create model (architecture chosen by name from the models namespace)
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](low_dim=args.low_dim)

    if not args.distributed:
        # AlexNet/VGG: only the conv features are data-parallelised
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.to(get_device(args.gpu))
        else:
            model = torch.nn.DataParallel(model).to(get_device(args.gpu))
    else:
        model.to(get_device(args.gpu))
        model = torch.nn.parallel.DistributedDataParallel(model)

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # NOTE(review): normalize is defined but never applied — Dataset below
    # takes no transform argument; confirm normalisation happens inside Dataset.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # n_frames presumably comes from module scope — TODO confirm
    train_dataset = Dataset(traindir, n_frames)
    val_dataset = Dataset(valdir, n_frames)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=True,  #(train_sampler is None),
        num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers)

    # define lemniscate (memory bank) and loss function (criterion)
    ndata = train_dataset.__len__()
    if args.nce_k > 0:
        # NCE with nce_k negatives; this NCEAverage variant also takes the gpu id
        lemniscate = NCEAverage(args.gpu, args.low_dim, ndata, args.nce_k,
                                args.nce_t, args.nce_m).to(get_device(args.gpu))
        criterion = NCECriterion(ndata).to(get_device(args.gpu))
    else:
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).to(get_device(args.gpu))
        criterion = nn.CrossEntropyLoss().to(get_device(args.gpu))

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        # evaluation-only mode: run kNN once and exit
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch)

        # evaluate on validation set: current, past and future NN scores
        prec1, prec1_past, prec1_future = NN(epoch, model, lemniscate,
                                             train_loader, val_loader)
        add_epoch_score('epoch_scores.txt', epoch, prec1)
        add_epoch_score('epoch_scores_past.txt', epoch, prec1_past)
        add_epoch_score('epoch_scores_future.txt', epoch, prec1_future)

        # Sascha: This is a bug because it seems prec1 or best_prec1 is a
        # vector at some point with more than one entry
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, epoch)

        # separate best-checkpoint tracking for the 'past' score
        is_best_past = prec1_past > best_prec1_past
        best_prec1_past = max(prec1_past, best_prec1_past)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                'best_prec1_past': best_prec1_past,
                'optimizer': optimizer.state_dict(),
            }, is_best_past, epoch, best_mod='_past')

        # separate best-checkpoint tracking for the 'future' score
        is_best_future = prec1_future > best_prec1_future
        best_prec1_future = max(prec1_future, best_prec1_future)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                'best_prec1_future': best_prec1_future,
                'optimizer': optimizer.state_dict(),
            }, is_best_future, epoch, best_mod='_future')

    # evaluate KNN after last epoch
    kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
def main():
    """Entry point: instance-discrimination training on building imagery.

    Trains a small residual-attention model (First_Floor_Binary dataset) with
    the lemniscate NCE / linear-average criterion. The resume path tolerates a
    checkpoint whose final-layer latent dimension differs from the current
    model (finetuning): the mismatched fc weights are dropped and the memory
    bank / optimizer are rebuilt from scratch. Mutates module globals ``args``
    and ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    args.distributed = args.world_size > 1
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    # Data loading code (ImageNet statistics)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # heavier augmentation than the stock recipe (0.5 jitter/grayscale,
    # vertical flips) — street-view imagery specific
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.3, 1.)),
        transforms.RandomGrayscale(p=0.5),
        transforms.ColorJitter(0.5, 0.5, 0.5, 0.5),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         normalize])

    train_dataset = First_Floor_Binary(args.attribute_name, args.train_data,
                                       args.image_folder,
                                       transform=train_transform,
                                       regression=args.regression,
                                       mask_buildings=args.mask_buildings,
                                       softmask=args.softmask)
    val_dataset = First_Floor_Binary(args.attribute_name, args.val_data,
                                     args.image_folder,
                                     transform=val_transform,
                                     regression=args.regression,
                                     mask_buildings=args.mask_buildings,
                                     softmask=args.softmask)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    model = ResidualAttentionModel_92_Small(args.low_dim, dropout=False)
    model = torch.nn.DataParallel(model).cuda()

    print('Train dataset instances: {}'.format(len(train_loader.dataset)))
    print('Val dataset instances: {}'.format(len(val_loader.dataset)))

    # define lemniscate (memory bank) and loss function (criterion)
    ndata = train_dataset.__len__()
    if args.nce_k > 0:
        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k,
                                args.nce_t, args.nce_m).cuda()
        criterion = NCECriterion(ndata).cuda()
    else:
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    #optimizer = RAdam(model.parameters())

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            # locate the final fc layer's weight key to compare latent sizes
            keyname = [
                keyname for keyname in model.state_dict().keys()
                if 'fc.weight' in keyname
            ][0]
            lat_vec_len_model = model.state_dict()[keyname].shape[0]
            lat_vec_len_checkpoint = checkpoint['state_dict'][keyname].shape[0]
            low_dim_differ = False
            if lat_vec_len_model != lat_vec_len_checkpoint:
                # checkpoint was trained with a different low_dim: drop the fc
                # weights and load the rest (finetuning scenario)
                low_dim_differ = True
                print(
                    'Warning: Latent vector sizes do not match. Assuming finetuning'
                )
                print(
                    'Lemniscate will be trained from scratch with new optimizer.'
                )
                del checkpoint['state_dict'][keyname]
                del checkpoint['state_dict'][keyname.replace('weight', 'bias')]
            missing_keys, unexpected_keys = model.load_state_dict(
                checkpoint['state_dict'], strict=False)
            if len(missing_keys) or len(unexpected_keys):
                print('Warning: Missing or unexpected keys found.')
                print('Missing: {}'.format(missing_keys))
                print('Unexpected: {}'.format(unexpected_keys))
            if not low_dim_differ:
                # The memory bank will be trained from scratch if
                # the low dim is different. Maybe later repopulated
                lemniscate = checkpoint['lemniscate']
                optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        # evaluation-only mode: run kNN once and exit
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # learning-rate schedule deliberately disabled here
        #adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch)

        # evaluate on validation set (nearest-neighbour accuracy)
        prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)

        # remember best prec@1 and save checkpoint
        # NOTE(review): best_prec1 must be initialised at module level before
        # the first comparison when not resuming — confirm against file top.
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.name)

    # evaluate KNN after last epoch
    kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
def full_training(args):
    """Run one self-supervised pretext-task training job end to end.

    Supports the tasks ``'ae'`` (autoencoder), ``'rot'`` (rotation),
    ``'jigsaw'`` and ``'inst_disc'`` (instance discrimination with lemniscate
    memory banks). Prepares loaders, optionally warm-starts from
    ``args.source``, trains for ``args.nb_epochs`` epochs while recording
    train/val loss+accuracy to ``<expdir>/results.npy``, and finally evaluates
    the best checkpoint on the test split. Skips work entirely if results
    already exist in ``args.expdir``.

    Fixes applied: the bare ``except:`` around ``torch.load`` now catches
    ``Exception`` (a bare except also swallows KeyboardInterrupt/SystemExit),
    and the ``eval('utils_pytorch.train_' + ...)`` lookups are replaced by the
    equivalent, safer ``getattr`` calls.
    """
    if not os.path.isdir(args.expdir):
        os.makedirs(args.expdir)
    elif os.path.exists(args.expdir + '/results.npy'):
        # job already finished — nothing to do
        return
    if 'ae' in args.task:
        os.mkdir(args.expdir + '/figs/')

    # rotation task multiplies each image into 4 rotated copies, so shrink
    # the batch to keep the effective batch size constant
    train_batch_size = args.train_batch_size // 4 if args.task == 'rot' else args.train_batch_size
    test_batch_size = args.test_batch_size // 4 if args.task == 'rot' else args.test_batch_size
    # instance discrimination needs per-sample indices to address the memory bank
    yield_indices = (args.task == 'inst_disc')
    datadir = args.datadir + args.dataset

    trainloader, valloader, num_classes = general_dataset_loader.prepare_data_loaders(
        datadir, image_dim=args.image_dim, yield_indices=yield_indices,
        train_batch_size=train_batch_size, test_batch_size=test_batch_size,
        train_on_10_percent=args.train_on_10,
        train_on_half_classes=args.train_on_half)
    # second call only for the (full) test loader
    _, testloader, _ = general_dataset_loader.prepare_data_loaders(
        datadir, image_dim=args.image_dim, yield_indices=yield_indices,
        train_batch_size=train_batch_size, test_batch_size=test_batch_size,
    )
    args.num_classes = num_classes
    # output width depends on the pretext task, not the dataset
    if args.task == 'rot':
        num_classes = 4
    elif args.task == 'inst_disc':
        num_classes = args.low_dim

    # build the network for the chosen task
    if args.task == 'ae':
        net = models.AE([args.code_dim], image_dim=args.image_dim)
    elif args.task == 'jigsaw':
        net = JigsawModel(num_perms=args.num_perms, code_dim=args.code_dim,
                          gray_prob=args.gray_prob, image_dim=args.image_dim)
    else:
        net = models.resnet26(num_classes, mlp_depth=args.mlp_depth,
                              normalize=(args.task == 'inst_disc'))

    if args.task == 'inst_disc':
        # separate memory banks for train and val sets, stashed on args so the
        # task-specific train/test functions can reach them
        train_lemniscate = LinearAverage(args.low_dim, len(trainloader.dataset),
                                         args.nce_t, args.nce_m)
        train_lemniscate.cuda()
        args.train_lemniscate = train_lemniscate
        test_lemniscate = LinearAverage(args.low_dim, len(valloader.dataset),
                                        args.nce_t, args.nce_m)
        test_lemniscate.cuda()
        args.test_lemniscate = test_lemniscate

    if args.source:
        # warm start from a previously trained checkpoint
        try:
            old_net = torch.load(args.source)
        except Exception:
            # python2-pickled checkpoint: retry with latin1 encoding
            print("Falling back encoding")
            from functools import partial
            import pickle
            pickle.load = partial(pickle.load, encoding="latin1")
            pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
            old_net = torch.load(args.source,
                                 map_location=lambda storage, loc: storage,
                                 pickle_module=pickle)
        old_net = old_net['net']
        if hasattr(old_net, "module"):
            # unwrap DataParallel
            old_net = old_net.module
        old_state_dict = old_net.state_dict()
        new_state_dict = net.state_dict()
        # copy everything except the classifier head, remapping a matching
        # head into the new nested 'linears.0.0.*' layout
        for key, weight in old_state_dict.items():
            if 'linear' not in key:
                new_state_dict[key] = weight
            elif key == 'linears.0.weight' and weight.shape[0] == num_classes:
                new_state_dict['linears.0.0.weight'] = weight
            elif key == 'linears.0.bias' and weight.shape[0] == num_classes:
                new_state_dict['linears.0.0.bias'] = weight
        net.load_state_dict(new_state_dict)
        del old_net

    net = torch.nn.DataParallel(net).cuda()
    start_epoch = 0
    # lower-is-better metrics (reconstruction / retrieval loss) start at +inf
    if args.task in ['ae', 'inst_disc']:
        best_acc = np.inf
    else:
        best_acc = -1
    # rows: train_loss, train_acc, test_loss, test_acc
    results = np.zeros((4, start_epoch + args.nb_epochs))
    net.cuda()
    cudnn.benchmark = True

    if args.task in ['ae']:
        args.criterion = nn.MSELoss()
    else:
        args.criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
                                lr=args.lr, momentum=0.9, weight_decay=args.wd)

    print("Start training")
    # resolve the task-specific train/test functions by name
    # (getattr replaces the previous eval() — same lookup, no code execution)
    train_func = getattr(utils_pytorch, 'train_' + args.task)
    test_func = getattr(utils_pytorch, 'test_' + args.task)
    if args.test_first:
        with torch.no_grad():
            test_func(0, valloader, net, best_acc, args, optimizer)

    for epoch in range(start_epoch, start_epoch + args.nb_epochs):
        utils_pytorch.adjust_learning_rate(optimizer, epoch, args)
        st_time = time.time()

        # Training and validation
        train_acc, train_loss = train_func(epoch, trainloader, net, args, optimizer)
        test_acc, test_loss, best_acc = test_func(epoch, valloader, net,
                                                  best_acc, args, optimizer)

        # Record statistics (persisted every epoch so crashes lose nothing)
        results[0:2, epoch] = [train_loss, train_acc]
        results[2:4, epoch] = [test_loss, test_acc]
        np.save(args.expdir + '/results.npy', results)
        print('Epoch lasted {0}'.format(time.time() - st_time))
        sys.stdout.flush()

        # rotation task converges quickly; optionally stop early
        if (args.task == 'rot') and (train_acc >= 98) and args.early_stopping:
            break

    if args.task == 'inst_disc':
        # free the memory banks stashed on args
        args.train_lemniscate = None
        args.test_lemniscate = None
    else:
        # final evaluation of the best checkpoint on the held-out test split
        # NOTE(review): path concatenation assumes args.expdir ends with '/'
        # (elsewhere '/results.npy' is joined explicitly) — confirm.
        best_net = torch.load(args.expdir + 'checkpoint.t7')['net']
        if args.task in ['ae', 'inst_disc']:
            best_acc = np.inf
        else:
            best_acc = -1
        final_acc, final_loss, _ = test_func(0, testloader, best_net,
                                             best_acc, args, None)
batch_size=100, shuffle=False, num_workers=2) classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') ndata = trainset.__len__() print('==> Building model..') net = models.__dict__['ResNet34'](low_dim=args.low_dim) # define leminiscate if args.nce_k > 0: lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t, args.nce_m) else: lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t, args.nce_m) if device == 'cuda': net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count())) cudnn.benchmark = True # Model if args.test_only or len(args.resume) > 0: # Load checkpoint. print('==> Resuming from checkpoint..') assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' checkpoint = torch.load('./checkpoint/' + args.resume) net.load_state_dict(checkpoint['net']) lemniscate = checkpoint['lemniscate'] best_acc = checkpoint['acc']
def main(args):
    """Entry point: CIFAR-10 instance discrimination with iterative clustering.

    Trains ResNet18 embeddings with a batch-wise criterion plus an ICR
    (cluster-neighbour) criterion. Runs 5 outer rounds of 200 epochs each; in
    round 0 the cluster count is annealed log-linearly from ndata down to
    ``args.nmb_cluster``, afterwards it stays fixed. Each epoch re-extracts
    features, re-clusters (faiss), rebuilds the ICR relation, trains, and
    evaluates with kNN, checkpointing best/last per round.

    NOTE(review): ``writer`` is used below but not defined in this function —
    presumably a module-level SummaryWriter; confirm. ``round`` shadows the
    builtin of the same name.
    """
    # Data
    print('==> Preparing data..')
    _size = 32  # CIFAR-10 native resolution
    transform_train = transforms.Compose([
        transforms.Resize(size=_size),
        transforms.RandomResizedCrop(size=_size, scale=(0.2, 1.)),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # CIFAR-10 channel statistics
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(size=_size),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = datasets.CIFAR10Instance(root='./data', train=True,
                                        download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True, num_workers=4)
    testset = datasets.CIFAR10Instance(root='./data', train=False,
                                       download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=False, num_workers=4)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')
    ndata = trainset.__len__()

    print('==> Building model..')
    net = models.__dict__['ResNet18'](low_dim=args.low_dim)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        net = torch.nn.DataParallel(net, device_ids=range(
            torch.cuda.device_count()))
        cudnn.benchmark = True

    criterion = ICRcriterion()
    # define loss function: inner product loss within each mini-batch
    uel_criterion = BatchCriterion(args.batch_m, args.batch_t, args.batch_size,
                                   ndata)

    net.to(device)
    criterion.to(device)
    uel_criterion.to(device)

    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    if args.test_only or len(args.resume) > 0:
        # Load checkpoint.
        model_path = 'checkpoint/' + args.resume
        print('==> Resuming from checkpoint..')
        # NOTE(review): asserts args.model_dir exists but loads from
        # 'checkpoint/' — confirm these are meant to be the same directory.
        assert os.path.isdir(
            args.model_dir), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(model_path)
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

    # define leminiscate: in test-only mode, pre-populate the memory bank from
    # extracted features; otherwise start it empty
    if args.test_only and len(args.resume) > 0:
        trainFeatures, feature_index = compute_feature(trainloader, net,
                                                       len(trainset), args)
        lemniscate = LinearAverage(torch.tensor(trainFeatures), args.low_dim,
                                   ndata, args.nce_t, args.nce_m)
    else:
        lemniscate = LinearAverage(torch.tensor([]), args.low_dim, ndata,
                                   args.nce_t, args.nce_m)
    lemniscate.to(device)

    # define optimizer
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9,
                                weight_decay=5e-4)

    # test acc
    if args.test_only:
        acc = kNN(0, net, trainloader, testloader, 200, args.batch_t, ndata,
                  low_dim=args.low_dim)
        exit(0)

    if len(args.resume) > 0:
        # resume: keep loaded best_acc, continue from the next epoch
        best_acc = best_acc
        start_epoch = start_epoch + 1
    else:
        best_acc = 0  # best test accuracy
        start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    icr2 = ICRDiscovery(ndata)

    for round in range(5):
        for epoch in range(start_epoch, 200):
            # trainFeatures are trainloader features; shuffle=True, so
            # feature_index maps features back to dataset order
            trainFeatures, feature_index = compute_feature(
                trainloader, net, len(trainset), args)

            if round == 0:
                # anneal cluster count log-linearly from ndata (epoch 0)
                # down to args.nmb_cluster over the 200 epochs
                y = -1 * math.log10(ndata) / 200 * epoch + math.log10(ndata)
                cluster_num = int(math.pow(10, y))
                if cluster_num <= args.nmb_cluster:
                    cluster_num = args.nmb_cluster
                print('cluster number: ' + str(cluster_num))

                # clustering algorithm to use (faiss-backed)
                deepcluster = clustering.__dict__[args.clustering](
                    int(cluster_num))
                # run clustering on the extracted features
                clustering_loss = deepcluster.cluster(trainFeatures,
                                                      feature_index,
                                                      verbose=args.verbose)
                L = np.array(deepcluster.images_lists)
                image_dict = deepcluster.images_dict

                print('create ICR ...')
                icrtime = time.time()
                # warm-up epochs use plain cluster assignment; afterwards use
                # the confidence-weighted PreScore relation
                if epoch < args.warm_epoch:
                    icr = cluster_assign(epoch, L, trainFeatures,
                                         feature_index, args.cluster_ratio, 1)
                else:
                    icr = PreScore(epoch, L, image_dict, trainFeatures,
                                   feature_index, trainset, args.high_ratio,
                                   args.cluster_ratio, args.alpha, args.beta)
                print('calculate ICR time is: {}'.format(time.time() - icrtime))
                writer.add_scalar('icr_time', (time.time() - icrtime),
                                  epoch + round * 200)
            else:
                # later rounds: fixed cluster count, always PreScore
                cluster_num = args.nmb_cluster
                print('cluster number: ' + str(cluster_num))

                deepcluster = clustering.__dict__[args.clustering](
                    int(cluster_num))
                clustering_loss = deepcluster.cluster(trainFeatures,
                                                      feature_index,
                                                      verbose=args.verbose)
                L = np.array(deepcluster.images_lists)
                image_dict = deepcluster.images_dict

                print('create ICR ...')
                icrtime = time.time()
                icr = PreScore(epoch, L, image_dict, trainFeatures,
                               feature_index, trainset, args.high_ratio,
                               args.cluster_ratio, args.alpha, args.beta)
                print('calculate ICR time is: {}'.format(time.time() - icrtime))
                writer.add_scalar('icr_time', (time.time() - icrtime),
                                  epoch + round * 200)

            # one training epoch; returns the updated ICR state for reuse
            icr2 = train(epoch, net, optimizer, lemniscate, criterion,
                         uel_criterion, trainloader, icr, icr2,
                         args.stage_update, args.lr, device, round)

            print('----------Evaluation---------')
            start = time.time()
            acc = kNN(0, net, trainloader, testloader, 200, args.batch_t,
                      ndata, low_dim=args.low_dim)
            print("Evaluation Time: '{}'s".format(time.time() - start))
            writer.add_scalar('nn_acc', acc, epoch + round * 200)

            if acc > best_acc:
                print('Saving..')
                state = {
                    'net': net.state_dict(),
                    'acc': acc,
                    'epoch': epoch,
                }
                if not os.path.isdir(args.model_dir):
                    os.mkdir(args.model_dir)
                torch.save(state,
                           './checkpoint/ckpt_best_round_{}.t7'.format(round))
                best_acc = acc

            # always keep the latest checkpoint for this round as well
            state = {
                'net': net.state_dict(),
                'acc': acc,
                'epoch': epoch,
            }
            torch.save(state, './checkpoint/ckpt_last_round_{}.t7'.format(round))

            print(
                '[Round]: {} [Epoch]: {} \t accuracy: {}% \t (best acc: {}%)'.
                format(round, epoch, acc, best_acc))
# Model if args.test_only or len(args.resume)>0: # Load checkpoint. print('==> Resuming from checkpoint..') assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' checkpoint = torch.load('./checkpoint/'+args.resume) net = checkpoint['net'] lemniscate = checkpoint['lemniscate'] best_acc = checkpoint['acc'] start_epoch = checkpoint['epoch'] else: print('==> Building model..') net = models.__dict__['ResNet50'](low_dim=args.low_dim) # define leminiscate lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum) # define loss function criterion = NCACrossEntropy(torch.LongTensor(trainloader.dataset.targets)) if use_cuda: net.cuda() net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count())) lemniscate.cuda() criterion.cuda() cudnn.benchmark = True if args.test_only: acc = kNN(0, net, lemniscate, trainloader, testloader, 30, args.temperature) sys.exit(0)
batch_size=100, shuffle=False, num_workers=4) classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') ndata = trainset.__len__() print('==> Building model..') net = models.__dict__['ResNet18'](low_dim=args.low_dim) # define leminiscate if args.nce_k > 0: lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t, args.nce_m, None, lib.get_dev()) else: lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t, args.nce_m) LA = LinearAverage(args.low_dim, ndata, args.nce_t, args.nce_m) if device == 'cuda': # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count())) cudnn.benchmark = True # Model if args.test_only or len(args.resume) > 0: # Load checkpoint. print('==> Resuming from checkpoint..') assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' checkpoint = torch.load('./checkpoint/' + args.resume) net.load_state_dict(checkpoint['net']) lemniscate = checkpoint['lemniscate'] best_acc = checkpoint['acc']
def main():
    """Unsupervised instance-discrimination training (xView-style variant).

    Builds the model (optionally distributed / DataParallel), constructs the
    train/val loaders with one of three sampling strategies, sets up the
    lemniscate memory bank + NCE or cross-entropy loss, optionally resumes or
    fine-tunes from a checkpoint (wrapping SGD in FP16_Optimizer), optionally
    recomputes the memory bank, then runs a number of optional analysis modes
    (tsne / pca / kmeans / select / h_cluster) before the main training loop.
    Uses globals `args` and `best_prec1`.
    """
    global args, best_prec1
    args = parser.parse_args()
    # Initialize distributed processing
    args.distributed = args.world_size > 1
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True,
                                           low_dim=args.low_dim)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](low_dim=args.low_dim)
    if not args.distributed:
        # alexnet/vgg: parallelize only the conv features (classifier stays on
        # one device); other arches are wrapped whole.
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet stats
        std=[0.229, 0.224, 0.225])
    # normalize = transforms.Normalize(mean=[0.234, 0.191, 0.159],  # xView stats
    #                                  std=[0.173, 0.143, 0.127])
    print("Creating datasets")
    # NOTE(review): `cj` is read from args but the ColorJitter lines that used
    # it are commented out, so color jitter is currently disabled.
    cj = args.color_jit
    train_dataset = datasets.ImageFolderInstance(
        traindir,
        transforms.Compose([
            transforms.Resize((224, 224)),
            # transforms.Grayscale(3),
            # transforms.ColorJitter(cj, cj, cj, cj),
            #transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(45),
            transforms.ToTensor(),
            normalize,
        ]))

    # Choose exactly one sampler: distributed > balanced > random subset.
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    elif args.balanced_sampling:
        print("Using balanced sampling")
        # Here's where we compute the weights for WeightedRandomSampler
        class_counts = {v: 0 for v in train_dataset.class_to_idx.values()}
        for path, ndx in train_dataset.samples:
            class_counts[ndx] += 1
        total = float(np.sum([v for v in class_counts.values()]))
        class_probs = [
            class_counts[ndx] / total for ndx in range(len(class_counts))
        ]
        # make a list of class probabilities corresponding to the entries in
        # train_dataset.samples
        reciprocal_weights = [
            class_probs[idx] for i, (_, idx) in enumerate(train_dataset.samples)
        ]
        # weights are the reciprocal of the above
        weights = (1 / torch.Tensor(reciprocal_weights))
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            weights, len(train_dataset), replacement=True)
    else:
        # if args.red_data is < 1, then the training is done with a subsample
        # of the total data. Otherwise it's the total data.
        data_size = len(train_dataset)
        sub_index = np.random.randint(0, data_size,
                                      round(args.red_data * data_size))
        sub_index.sort()
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(sub_index)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)
    print("Training on", len(train_dataset.imgs),
          "images. Training batch size:", args.batch_size)
    if len(train_dataset.imgs) % args.batch_size != 0:
        print(
            "Warning: batch size doesn't divide the # of training images so ",
            len(train_dataset.imgs) % args.batch_size,
            "images will be skipped per epoch.")
        print("If you don't want to skip images, use a batch size in:",
              get_factors(len(train_dataset.imgs)))

    val_dataset = datasets.ImageFolderInstance(
        valdir,
        transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]))
    # NOTE(review): the computed divisor-based val_bs is immediately
    # overwritten by the hard-coded 100 on the next line — likely leftover
    # debugging; confirm which one is intended.
    val_bs = [
        factor for factor in get_factors(len(val_dataset)) if factor < 500
    ][-1]
    val_bs = 100
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=val_bs,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    print("Validating on", len(val_dataset), "images. Validation batch size:",
          val_bs)

    # define lemniscate and loss function (criterion)
    ndata = train_dataset.__len__()
    if args.nce_k > 0:
        # NOTE(review): unlike the sibling scripts, this NCEAverage is NOT
        # moved to .cuda() here — confirm it is moved elsewhere.
        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t,
                                args.nce_m)
        criterion = NCECriterion(ndata).cuda()
    else:
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            # Resume path: wrap in FP16_Optimizer (presumably apex-style mixed
            # precision — TODO confirm) BEFORE loading optimizer state.
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.static_loss,
                                       verbose=False)
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch']
            # best_prec1 = checkpoint['best_prec1']
            lemniscate = checkpoint['lemniscate']
            if args.select_load:
                pred = checkpoint['prediction']
            print("=> loaded checkpoint '{}' (epoch {}, best_prec1 )".format(
                args.resume, checkpoint['epoch']))  #, checkpoint['best_prec1']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    # optionally fine-tune a model trained on a different dataset
    elif args.fine_tune:
        print("=> loading checkpoint '{}'".format(args.fine_tune))
        checkpoint = torch.load(args.fine_tune)
        model.load_state_dict(checkpoint['state_dict'])
        # NOTE(review): here the optimizer state is loaded BEFORE wrapping in
        # FP16_Optimizer, the opposite order of the resume branch above — one
        # of the two is probably wrong; verify against FP16_Optimizer's
        # load_state_dict contract.
        optimizer.load_state_dict(checkpoint['optimizer'])
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss,
                                   verbose=False)
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.fine_tune, checkpoint['epoch']))
    else:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss,
                                   verbose=False)

    # Optionally recompute memory. If fine-tuning, then we must recompute memory
    if args.recompute_memory or args.fine_tune:
        # Aaron - Experiments show that iterating over torch.utils.data.DataLoader will skip the last few
        # unless the batch size evenly divides size of the data set. This shouldn't be the case
        # according to documentation, there's even a flag for drop_last, but it's not working
        # compute a good batch size for re-computing memory
        # (memory_bs is chosen as a divisor of the dataset size so every batch
        # is full — the indexing below relies on constant batch size).
        memory_bs = [
            factor for factor in get_factors(len(train_loader.dataset))
            if factor < 500
        ][-1]
        print("Recomputing memory using", train_dataset.root,
              "with a batch size of", memory_bs)
        # Temporarily swap in the deterministic val transform for the forward
        # passes, then restore the training transform afterwards.
        transform_bak = train_loader.dataset.transform
        train_loader.dataset.transform = val_loader.dataset.transform
        temploader = torch.utils.data.DataLoader(
            train_loader.dataset,
            batch_size=memory_bs,
            shuffle=False,
            num_workers=train_loader.num_workers,
            pin_memory=True)
        lemniscate.memory = torch.zeros(len(train_loader.dataset),
                                        args.low_dim).cuda()
        model.eval()
        with torch.no_grad():
            for batch_idx, (inputs, targets,
                            indexes) in enumerate(tqdm.tqdm(temploader)):
                batchSize = inputs.size(0)
                features = model(inputs)
                lemniscate.memory[batch_idx * batchSize:batch_idx * batchSize +
                                  batchSize, :] = features.data
        train_loader.dataset.transform = transform_bak
        model.train()

    cudnn.benchmark = True

    if args.evaluate:
        kNN(model, lemniscate, train_loader, val_loader, args.K, args.nce_t)
        return

    begin_train_time = datetime.datetime.now()

    # my_knn(model, lemniscate, train_loader, val_loader, args.K, args.nce_t, train_dataset, val_dataset)

    # Optional analysis / visualization modes (each gated by a CLI flag).
    if args.tsne:
        labels = idx_to_name(train_dataset, args.graph_labels)
        tsne(lemniscate, args.tsne, labels)
    if args.pca:
        labels = idx_to_name(train_dataset, args.graph_labels)
        pca(lemniscate, labels)
    if args.view_knn:
        my_knn(model, lemniscate, train_loader, val_loader, args.K, args.nce_t,
               train_dataset, val_dataset)
    if args.kmeans:
        # Cluster memory-bank embeddings, then restrict training to the
        # members of centroid 0 via a new sampler/loader.
        kmeans, yi = kmean(lemniscate, args.kmeans, 500, args.K, train_dataset)
        D, I = kmeans.index.search(lemniscate.memory.data.cpu().numpy(), 1)
        cent_group = {}
        data_cent = {}
        for n, i in enumerate(I):
            if i[0] not in cent_group.keys():
                cent_group[i[0]] = []
            cent_group[i[0]].append(n)
            data_cent[n] = i[0]
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            cent_group[0])
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)
        # lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t, args.nce_m)
        # criterion = NCECriterion(ndata).cuda()
        # lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t, args.nce_m)
    if args.tsne_grid:
        tsne_grid(val_loader, model)
    if args.h_cluster:
        for size in range(2, 3):
            # size = 20
            kmeans, topk = kmean(lemniscate, size, 500, 10, train_dataset)
            respred = torch.tensor([]).cuda()
            lab, idx = [[] for i in range(2)]
            num = 0
            '''
            for p,index,label in pred:
                respred = torch.cat((respred,p))
                if num == 0:
                    lab = label
                else:
                    lab += label
                idx.append(index)
                num+=1
            '''
            h_cluster(lemniscate, train_dataset, kmeans, topk,
                      size)  #, respred, lab, idx)
    # axis_explore(lemniscate, train_dataset)
    # kmeans_opt(lemniscate, 5)

    if args.select:
        # Active-selection mode: score a random ~1% subset (or select_size
        # fraction), pick a diverse subset, pre-train on it, then restore the
        # original subset sampler/optimizer for the main loop.
        if not args.select_load:
            pred = []
            if args.select_size:
                size = int(args.select_size * ndata)
            else:
                size = round(ndata / 100.0)
            sub_sample = np.random.randint(0, ndata, size=size)
            train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                sub_sample)
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                num_workers=args.workers,
                pin_memory=True,
                sampler=train_sampler)
            pred = div_train(train_loader, model, 0, pred)
        pred_features = []
        pred_labels = []
        pred_idx = []
        for inst in pred:
            feat, idx, lab = list(inst)
            pred_features.append(feat)
            pred_labels.append(lab)
            pred_idx.append(idx.data.cpu())
        if args.select_save:
            save_checkpoint(
                {
                    'epoch': args.start_epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'prediction': pred,
                    'lemniscate': lemniscate,
                    'optimizer': optimizer.state_dict(),
                }, 'select.pth.tar')
        min_idx = selection(pred_features, pred_idx, train_dataset,
                            args.select_num, args.select_thresh)
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(min_idx)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)
        # Fresh memory bank / optimizer with hard-coded hyperparameters for
        # the 50-epoch selection pre-training phase.
        lemniscate = NCEAverage(args.low_dim, ndata, 20, args.nce_t,
                                args.nce_m)
        optimizer = torch.optim.SGD(model.parameters(),
                                    0.1,
                                    momentum=0.1,
                                    weight_decay=0.00001)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss,
                                   verbose=False)
        for epoch in range(50):
            if args.distributed:
                train_sampler.set_epoch(epoch)
            adjust_learning_rate(optimizer, epoch)
            if epoch % 1 == 0:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'lemniscate': lemniscate,
                    'optimizer': optimizer.state_dict(),
                })
            train(train_loader, model, lemniscate, criterion, optimizer, epoch)
        # Restore the original random-subset sampler and hyperparameters.
        # NOTE(review): `sub_index` only exists when neither distributed nor
        # balanced sampling was used — this raises NameError otherwise.
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(sub_index)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)
        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t,
                                args.nce_m)
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss,
                                   verbose=False)

    if args.kmeans_opt:
        kmeans_opt(lemniscate, 500)

    # Main training loop: checkpoint every epoch, then train.
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)
        if epoch % 1 == 0:
            # evaluate on validation set
            #prec1 = NN(epoch, model, lemniscate, train_loader, train_loader) # was evaluating on train
            # prec1 = kNN(model, lemniscate, train_loader, val_loader, args.K, args.nce_t)
            # prec1 really should be renamed to prec5 as kNN now returns top5 score, but
            # it won't be backward's compatible as earlier models were saved with "best_prec1"
            # remember best prec@1 and save checkpoint
            # is_best = prec1 > best_prec1
            # best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                # 'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            })  # , is_best)
        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch)
        # kmeans,cent = kmeans()
        # group_train(train_loader, model, lemniscate, criterion, optimizer, epoch, kmeans, cent)

    # print elapsed time
    end_train_time = datetime.datetime.now()
    d = end_train_time - begin_train_time
    print(
        "Trained for %d epochs. Elapsed time: %s days, %.2dh: %.2dm: %.2ds" %
        (len(range(args.start_epoch, args.epochs)), d.days, d.seconds // 3600,
         (d.seconds // 60) % 60, d.seconds % 60))
def main():
    """Instance-discrimination training on the Foundation_Type_Binary dataset.

    Builds train/val loaders, a DataParallel resnet50 embedding model, the
    lemniscate memory bank (NCE or linear average), optionally resumes from a
    checkpoint with shape/compatibility checks, then runs the epoch loop with
    NN-based validation and best-checkpoint saving. Uses globals `args` and
    `best_prec1`.
    """
    global args, best_prec1
    args = parser.parse_args()
    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224, scale=(0.3, 1.)),
            transforms.RandomGrayscale(p=0.5),
            transforms.ColorJitter(0.5, 0.5, 0.5, 0.5),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            normalize])
    val_transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize])

    train_dataset = Foundation_Type_Binary(args.train_data,
                                           transform=train_transform,
                                           mask_buildings=args.mask_buildings)
    val_dataset = Foundation_Type_Binary(args.val_data,
                                         transform=val_transform,
                                         mask_buildings=args.mask_buildings)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    model = resnet50(low_dim=args.low_dim)
    model = torch.nn.DataParallel(model).cuda()

    print ('Train dataset instances: {}'.format(len(train_loader.dataset)))
    print ('Val dataset instances: {}'.format(len(val_loader.dataset)))

    # define lemniscate and loss function (criterion)
    ndata = train_dataset.__len__()
    if args.nce_k > 0:
        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t,
                                args.nce_m).cuda()
        criterion = NCECriterion(ndata).cuda()
    else:
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Resuming extends the run: train args.epochs MORE epochs on top
            # of the checkpoint's epoch count.
            args.epochs = args.start_epoch + args.epochs
            best_prec1 = checkpoint['best_prec1']
            # strict=False so fine-tuning from a partially matching model
            # works; mismatches are reported, not fatal.
            missing_keys, unexpected_keys = model.load_state_dict(
                checkpoint['state_dict'], strict=False)
            if len(missing_keys) or len(unexpected_keys):
                print('Warning: Missing or unexpected keys found.')
                print('Missing: {}'.format(missing_keys))
                print('Unexpected: {}'.format(unexpected_keys))
            # Only reuse the stored memory bank if its embedding width matches
            # the requested low_dim; otherwise assume fine-tuning.
            low_dim_checkpoint = checkpoint['lemniscate'].memory.shape[1]
            if low_dim_checkpoint == args.low_dim:
                lemniscate = checkpoint['lemniscate']
            else:
                print('Chosen low dim does not fit checkpoint. Assuming fine-tuning and not loading memory bank.')
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except ValueError:
                print('Training optimizer does not fit to checkpoint optimizer. Assuming fine-tuning and load optimizer from scratch. ')
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
        return

    # Baseline validation accuracy before any training.
    prec1 = NN(0, model, lemniscate, train_loader, val_loader)
    print('Start out precision: {}'.format(prec1))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'lemniscate': lemniscate,
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best, args.name)
def main():
    """Multi-seed unsupervised training on medical fundus images.

    Fixes all RNG seeds for determinism, then for each seed in
    [seedstart, seedend) builds the model, augmentations, train loader and
    three test loaders (amd / gon / pm), a LinearAverage memory bank with one
    of several batch-criterion losses, optionally resumes, and trains with
    TensorBoard logging and per-epoch checkpoints. Uses globals `args` and
    `best_prec1`.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Fix every RNG source for reproducibility; cudnn is forced deterministic
    # (benchmark off), trading speed for repeatability.
    my_whole_seed = 111
    random.seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.manual_seed(my_whole_seed)
    torch.cuda.manual_seed_all(my_whole_seed)
    torch.cuda.manual_seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(my_whole_seed)

    for kk_time in range(args.seedstart, args.seedend):
        args.seed = kk_time
        # NOTE(review): args.result is appended to on every loop iteration,
        # so later seeds accumulate earlier seed suffixes in the path.
        args.result = args.result + str(args.seed)

        # create model
        model = models.__dict__[args.arch](low_dim=args.low_dim,
                                           multitask=args.multitask,
                                           showfeature=args.showfeature,
                                           args=args)
        # # from models.Gresnet import ResNet18
        # model = ResNet18(low_dim=args.low_dim, multitask=args.multitask)
        model = torch.nn.DataParallel(model).cuda()

        # Data loading code
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        aug = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
        # aug = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.08, 1.), ratio=(3 / 4, 4 / 3)),
        #                           transforms.RandomHorizontalFlip(p=0.5),
        #                           get_color_distortion(s=1),
        #                           transforms.Lambda(lambda x: gaussian_blur(x)),
        #                           transforms.ToTensor(),
        #                           normalize])
        # aug = transforms.Compose([transforms.RandomRotation(60),
        #                           transforms.RandomResizedCrop(224, scale=(0.6, 1.)),
        #                           transforms.RandomGrayscale(p=0.2),
        #                           transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        #                           transforms.RandomHorizontalFlip(),
        #                           transforms.ToTensor(),
        #                           normalize])
        aug_test = transforms.Compose(
            [transforms.Resize(224),
             transforms.ToTensor(), normalize])

        # dataset
        import datasets.fundus_kaggle_dr as medicaldata
        train_dataset = medicaldata.traindataset(root=args.data, transform=aug,
                                                 train=True, args=args)
        # NOTE(review): worker_init_fn expects a callable taking a worker id;
        # `random.seed(my_whole_seed)` is evaluated HERE and passes None —
        # workers get no per-worker seeding from this. Verify intent.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size, shuffle=True,
            pin_memory=True, num_workers=8,
            drop_last=True if args.multiaug else False,
            worker_init_fn=random.seed(my_whole_seed))

        valid_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug_test,
                                                 train=False,
                                                 test_type="amd", args=args)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=args.batch_size, shuffle=False,
            pin_memory=True, num_workers=8,
            worker_init_fn=random.seed(my_whole_seed))
        valid_dataset_gon = medicaldata.traindataset(root=args.data,
                                                     transform=aug_test,
                                                     train=False,
                                                     test_type="gon", args=args)
        val_loader_gon = torch.utils.data.DataLoader(
            valid_dataset_gon,
            batch_size=args.batch_size, shuffle=False,
            pin_memory=True, num_workers=8,
            worker_init_fn=random.seed(my_whole_seed))
        valid_dataset_pm = medicaldata.traindataset(root=args.data,
                                                    transform=aug_test,
                                                    train=False,
                                                    test_type="pm", args=args)
        val_loader_pm = torch.utils.data.DataLoader(
            valid_dataset_pm,
            batch_size=args.batch_size, shuffle=False,
            pin_memory=True, num_workers=8,
            worker_init_fn=random.seed(my_whole_seed))

        # define lemniscate and loss function (criterion)
        ndata = train_dataset.__len__()
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        local_lemniscate = None

        # Pick the contrastive criterion by CLI flags (mutually exclusive,
        # first match wins); plain CrossEntropy is the fallback.
        if args.multitaskposrot:
            print("running multi task with positive")
            criterion = BatchCriterionRot(1, 0.1, args.batch_size, args).cuda()
        elif args.domain:
            print("running domain with four types--unify ")
            from lib.BatchAverageFour import BatchCriterionFour
            # criterion = BatchCriterionTriple(1, 0.1, args.batch_size, args).cuda()
            criterion = BatchCriterionFour(1, 0.1, args.batch_size,
                                           args).cuda()
        elif args.multiaug:
            print("running multi task")
            criterion = BatchCriterion(1, 0.1, args.batch_size, args).cuda()
        else:
            criterion = nn.CrossEntropyLoss().cuda()

        if args.multitask:
            cls_criterion = nn.CrossEntropyLoss().cuda()
        else:
            cls_criterion = None

        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                     weight_decay=args.weight_decay)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                lemniscate = checkpoint['lemniscate']
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        if args.evaluate:
            knn_num = 100
            auc, acc, precision, recall, f1score = kNN(
                args, model, lemniscate, train_loader, val_loader, knn_num,
                args.nce_t, 2)
            return

        # mkdir result folder and tensorboard
        os.makedirs(args.result, exist_ok=True)
        writer = SummaryWriter("runs/" + str(args.result.split("/")[-1]))
        writer.add_text('Text', str(args))

        # copy code: snapshot all .py files (top level and one dir deep) into
        # the result folder for provenance.
        import shutil, glob
        source = glob.glob("*.py")
        source += glob.glob("*/*.py")
        os.makedirs(args.result + "/code_file", exist_ok=True)
        for file in source:
            name = file.split("/")[0]
            if name == file:
                shutil.copy(file, args.result + "/code_file/")
            else:
                os.makedirs(args.result + "/code_file/" + name, exist_ok=True)
                shutil.copy(file, args.result + "/code_file/" + name)

        for epoch in range(args.start_epoch, args.epochs):
            lr = adjust_learning_rate(optimizer, epoch, args, [100, 200])
            writer.add_scalar("lr", lr, epoch)

            # train for one epoch
            loss = train(train_loader, model, lemniscate, local_lemniscate,
                         criterion, cls_criterion, optimizer, epoch, writer)
            writer.add_scalar("train_loss", loss, epoch)

            # (disabled) periodic kNN evaluation every gap_int epochs on the
            # amd/gon/pm val loaders with auc/acc/precision/recall/f1score
            # written to TensorBoard — left out of the active code path.

            # save checkpoint
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'lemniscate': lemniscate,
                    'optimizer': optimizer.state_dict(),
                },
                filename=args.result + "/fold" + str(args.seedstart) +
                "-epoch-" + str(epoch) + ".pth.tar")
def build_model():
    """Construct the network, memory bank, optimizer, criterion and scheduler.

    Reads hyperparameters from the module-level ``args`` namespace. Optionally
    restores net/optimizer/lemniscate/best_acc/start_epoch from
    ``args.resume``, then builds the LR scheduler so its ``last_epoch`` lines
    up with the resumed epoch.

    Returns:
        (net, lemniscate, optimizer, criterion, scheduler, best_acc,
         start_epoch)

    Raises:
        ValueError: if ``args.architecture`` or ``args.lr_scheduler`` is not
            one of the supported choices.
        RuntimeError: if the multi-step scheduler has no step table for
            ``args.epochs``.
    """
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    if args.architecture == 'resnet18':
        net = models.__dict__['resnet18_cifar'](low_dim=args.low_dim)
    elif args.architecture == 'wrn-28-2':
        net = models.WideResNet(depth=28,
                                num_classes=args.low_dim,
                                widen_factor=2,
                                dropRate=0).to(args.device)
    elif args.architecture == 'wrn-28-10':
        net = models.WideResNet(depth=28,
                                num_classes=args.low_dim,
                                widen_factor=10,
                                dropRate=0).to(args.device)
    else:
        # Fail fast with a clear message instead of the NameError on `net`
        # that an unknown architecture used to trigger further down.
        raise ValueError(
            "unsupported architecture: {}".format(args.architecture))

    # define lemniscate: non-parametric memory bank, NCE-sampled when
    # nce_k > 0, full softmax (LinearAverage) otherwise.
    if args.nce_k > 0:
        lemniscate = NCEAverage(args.low_dim, args.ndata, args.nce_k,
                                args.nce_t, args.nce_m)
    else:
        lemniscate = LinearAverage(args.low_dim, args.ndata, args.nce_t,
                                   args.nce_m)

    if args.device == 'cuda':
        net = torch.nn.DataParallel(net,
                                    device_ids=range(
                                        torch.cuda.device_count()))
        cudnn.benchmark = True

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=args.weight_decay,
                          nesterov=True)

    # Model
    if args.test_only or len(args.resume) > 0:
        # Load checkpoint. NOTE(review): test_only with an empty args.resume
        # would call torch.load('') and fail — confirm callers always pass a
        # checkpoint path in that mode.
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.resume)
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lemniscate = checkpoint['lemniscate']
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch'] + 1

    # Build the scheduler AFTER start_epoch is known so last_epoch resumes
    # the decay schedule at the right point.
    if args.lr_scheduler == 'multi-step':
        if args.epochs == 200:
            steps = [60, 120, 160]
        elif args.epochs == 600:
            steps = [180, 360, 480, 560]
        else:
            raise RuntimeError(
                f"need to config steps for epoch = {args.epochs} first.")
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             steps,
                                             gamma=0.2,
                                             last_epoch=start_epoch - 1)
    elif args.lr_scheduler == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   args.epochs,
                                                   eta_min=0.00001,
                                                   last_epoch=start_epoch - 1)
    elif args.lr_scheduler == 'cosine-with-restart':
        scheduler = CosineAnnealingLRWithRestart(optimizer,
                                                 eta_min=0.00001,
                                                 last_epoch=start_epoch - 1)
    else:
        raise ValueError("not supported")

    # define loss function: NCEAverage exposes attribute K, LinearAverage
    # does not — use that to pick the matching criterion.
    if hasattr(lemniscate, 'K'):
        criterion = NCECriterion(args.ndata)
    else:
        criterion = nn.CrossEntropyLoss()

    net.to(args.device)
    lemniscate.to(args.device)
    criterion.to(args.device)

    return net, lemniscate, optimizer, criterion, scheduler, best_acc, start_epoch
def main():
    """Round-based instance discrimination with reliable-neighbour search.

    Builds the model and ImageNet-style loaders, a LinearAverage memory bank,
    a ReliableSearch structure and its matching loss, optionally resumes, then
    runs `args.rounds` outer rounds: from round 1 onward the memory bank is
    recomputed and the reliable-pair sets refreshed before each round's epoch
    loop. Checkpoints best-per-epoch and last-per-round. Uses globals `args`
    and `best_prec1`.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](low_dim=args.low_dim)

    # alexnet/vgg: parallelize only the features; other arches wrapped whole.
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolderInstance(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    # Ground-truth labels kept on GPU; used only to measure consistency of
    # the reliable-pair sets in rlb.update below.
    train_labels = torch.tensor(train_dataset.targets).long().cuda()
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolderInstance(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define lemniscate and loss function (criterion)
    ndata = train_dataset.__len__()
    lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                               args.nce_m).cuda()
    # Two-threshold reliable-neighbour search over the memory bank.
    rlb = ReliableSearch(ndata, args.low_dim, args.threshold_1,
                         args.threshold_2, args.batch_size).cuda()
    criterion = ReliableCrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # NOTE(review): start_epoch is forced to 0 on resume (epoch is not
            # read back from the checkpoint) — each round restarts its epoch
            # count; confirm this is intended.
            args.start_epoch = 0
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
        return

    for rnd in range(args.start_round, args.rounds):
        # Round 0 trains on the raw instance task; later rounds first rebuild
        # the memory bank and refresh the reliable-pair sets.
        if rnd > 0:
            memory = recompute_memory(model, lemniscate, train_loader,
                                      val_loader, args.batch_size,
                                      args.workers)
            num_reliable_1, consistency_1, num_reliable_2, consistency_2 = rlb.update(
                memory, train_labels)
            print(
                'Round [%02d/%02d]\tReliable1: %.12f\tReliable2: %.12f\tConsistency1: %.12f\tConsistency2: %.12f'
                % (rnd, args.rounds, num_reliable_1, num_reliable_2,
                   consistency_1, consistency_2))

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch)

            # train for one epoch
            train(train_loader, model, lemniscate, rlb, criterion, optimizer,
                  epoch)

            # evaluate on validation set
            prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'lemniscate': lemniscate,
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                    #}, is_best, filename='ckpts/%02d-%04d-checkpoint.pth.tar'%(rnd+1, epoch + 1))
                },
                is_best)
            # Also keep a fixed-name per-round snapshot of the latest state.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'lemniscate': lemniscate,
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                },
                is_best=False,
                filename='ckpts/%02d-checkpoint.pth.tar' % (rnd + 1))

        # evaluate KNN after last epoch
        top1, top5 = kNN(0, model, lemniscate, train_loader, val_loader, 200,
                         args.nce_t)
        print('Round [%02d/%02d]\tTop1: %.2f\tTop5: %.2f' %
              (rnd + 1, args.rounds, top1, top5))
def main():
    """Train an IdFe instance-discrimination feature extractor.

    Parses CLI args, builds the model/dataset pair selected by
    ``args.dataset`` ("lung" 3D or "gland" 2D), sets up an NCE or
    linear-average memory bank, optionally resumes from a checkpoint,
    then trains for ``args.epochs`` epochs, checkpointing every 5 epochs
    with "best" defined as the minimum average epoch loss.

    Side effects: blocks on stdin for operator confirmation, writes
    TensorBoard logs, moves the model to CUDA, and saves checkpoints.

    NOTE(review): ``best_prec1`` and ``min_avgloss`` are declared global
    but only assigned here on resume — presumably initialized at module
    level; confirm, otherwise a fresh run raises NameError at the first
    ``min(epoch_loss, min_avgloss)``.
    """
    global args, best_prec1, min_avgloss
    args = parser.parse_args()
    # Blocks until the operator confirms; args.train_time tags this run.
    input("Begin the {} time's training".format(args.train_time))
    writer_log_dir = "/data/fhz/unsupervised_recommendation/idfe_runs/idfe_train_time:{}".format(
        args.train_time)
    writer = SummaryWriter(log_dir=writer_log_dir)
    if args.dataset == "lung":
        # Build train dataloader only; val_dloader is built inside the test
        # function. 3D model; encoder and projection head are wrapped in
        # DataParallel separately so their state_dicts stay addressable.
        model = idfe.IdFe3d(feature_dim=args.latent_dim)
        model.encoder = torch.nn.DataParallel(model.encoder)
        model.linear_map = torch.nn.DataParallel(model.linear_map)
        model = model.cuda()
        train_datalist, test_datalist = multi_cross_validation()
        ndata = len(train_datalist)
    elif args.dataset == "gland":
        dataset_path = "/data/fhz/MICCAI2015/npy"
        # 2D model for the gland dataset; same DataParallel wrapping scheme.
        model = idfe.IdFe2d(feature_dim=args.latent_dim)
        model.encoder = torch.nn.DataParallel(model.encoder)
        model.linear_map = torch.nn.DataParallel(model.linear_map)
        model = model.cuda()
        train_datalist = glob(path.join(dataset_path, "train", "*.npy"))
        ndata = len(train_datalist)
    else:
        raise FileNotFoundError("Dataset {} Not Found".format(args.dataset))
    if args.nce_k > 0:
        # nce_k > 0 selects NCE with nce_k noise samples per instance.
        """ Here we use NCE to calculate loss """
        lemniscate = NCEAverage(args.latent_dim, ndata, args.nce_k,
                                args.nce_t, args.nce_m).cuda()
        criterion = NCECriterion(ndata).cuda()
    else:
        # Otherwise treat instance discrimination as an ndata-way softmax
        # over the non-parametric memory bank.
        lemniscate = LinearAverage(args.latent_dim, ndata, args.nce_t, args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE(review): torch.load unpickles arbitrary objects — only
            # resume from trusted checkpoint files.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            min_avgloss = checkpoint['min_avgloss']
            # Encoder and head were saved separately (see save_checkpoint
            # payload below), so they are restored separately too.
            model.encoder.load_state_dict(checkpoint['encoder_state_dict'])
            model.linear_map.load_state_dict(
                checkpoint['linear_map_state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Restore the exact file list so memory-bank indices keep
            # matching dataset indices across the resume.
            train_datalist = checkpoint['train_datalist']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    # Dataloaders are built after the (possible) resume so they use the
    # restored train_datalist.
    if args.dataset == "lung":
        train_dset = LungDataSet(data_path_list=train_datalist, augment_prob=args.aug_prob)
        train_dloader = DataLoader(dataset=train_dset, batch_size=args.batch_size, shuffle=True,
                                   num_workers=args.workers, pin_memory=True)
    elif args.dataset == "gland":
        train_dset = GlandDataset(data_path_list=train_datalist, need_seg_label=False,
                                  augment_prob=args.aug_prob)
        train_dloader = DataLoader(dataset=train_dset, batch_size=args.batch_size, shuffle=True,
                                   num_workers=args.workers, pin_memory=True)
    else:
        raise FileNotFoundError("Dataset {} Not Found".format(args.dataset))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        epoch_loss = train(train_dloader, model=model, lemniscate=lemniscate,
                           criterion=criterion, optimizer=optimizer, epoch=epoch,
                           writer=writer, dataset=args.dataset)
        # Checkpoint every 5 epochs.
        if (epoch + 1) % 5 == 0:
            if args.dataset == "lung":
                """
                Here we define the best point as the minimum average epoch loss
                """
                accuracy = list([])
                # Disabled five-fold kNN validation on the memory bank,
                # kept for reference:
                # for i in range(5):
                #     train_feature = lemniscate.memory.clone()
                #     test_datalist = train_datalist[five_cross_idx[i]:five_cross_idx[i + 1]]
                #     test_feature = train_feature[five_cross_idx[i]:five_cross_idx[i + 1], :]
                #     train_indices = [train_datalist.index(d) for d in train_datalist if d not in test_datalist]
                #     tmp_train_feature = torch.index_select(train_feature, 0, torch.tensor(train_indices).cuda())
                #     tmp_train_datalist = [train_datalist[i] for i in train_indices]
                #     test_label = np.array(
                #         [int(eval(re.match("(.*)_(.*)_annotations.npy", path.basename(raw_cube_path)).group(2)) > 3)
                #          for raw_cube_path in test_datalist], dtype=np.float)
                #     tmp_train_label = np.array(
                #         [int(eval(re.match("(.*)_(.*)_annotations.npy", path.basename(raw_cube_path)).group(2)) > 3)
                #          for raw_cube_path in tmp_train_datalist], dtype=np.float)
                #     accuracy.append(
                #         kNN(tmp_train_feature, tmp_train_label, test_feature, test_label, K=20, sigma=1 / 10))
                # accuracy = mean(accuracy)
                is_best = (epoch_loss < min_avgloss)
                min_avgloss = min(epoch_loss, min_avgloss)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        "train_time": args.train_time,
                        "encoder_state_dict": model.encoder.state_dict(),
                        "linear_map_state_dict": model.linear_map.state_dict(),
                        'lemniscate': lemniscate,
                        'min_avgloss': min_avgloss,
                        'dataset': args.dataset,
                        'optimizer': optimizer.state_dict(),
                        'train_datalist': train_datalist
                    }, is_best)
                # knn_text = "In epoch :{} the five cross validation accuracy is :{}".format(epoch, accuracy * 100.0)
                # # print(knn_text)
                # writer.add_text("knn/text", knn_text, epoch)
                # writer.add_scalar("knn/accuracy", accuracy, global_step=epoch)
            elif args.dataset == "gland":
                # Same minimum-average-loss criterion and payload as the
                # lung branch.
                is_best = (epoch_loss < min_avgloss)
                min_avgloss = min(epoch_loss, min_avgloss)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        "train_time": args.train_time,
                        "encoder_state_dict": model.encoder.state_dict(),
                        "linear_map_state_dict": model.linear_map.state_dict(),
                        'lemniscate': lemniscate,
                        'min_avgloss': min_avgloss,
                        'dataset': args.dataset,
                        'optimizer': optimizer.state_dict(),
                        'train_datalist': train_datalist,
                    }, is_best)