def get_transforms(auto_augment, input_sizes, m, mean, n, std):
    if auto_augment:
        # AutoAugment + Cutout
        train_transforms = Compose([
            RandomCrop(size=input_sizes, padding=4, fill=128),
            RandomHorizontalFlip(p=0.5),
            CIFAR10Policy(),
            ToTensor(),
            Normalize(mean=mean, std=std),
            Cutout(n_holes=1, length=16)
        ])
    else:
        # RandAugment + Cutout
        train_transforms = Compose([
            RandomCrop(size=input_sizes, padding=4, fill=128),
            RandomHorizontalFlip(p=0.5),
            RandomRandAugment(n=n, m_max=m),  # This version includes cutout
            ToTensor(),
            Normalize(mean=mean, std=std)
        ])

    test_transforms = Compose([
        ToTensor(),
        Normalize(mean=mean, std=std)
    ])
    return test_transforms, train_transforms
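# The snippets in this file all rely on a `Cutout` transform in the style of
# DeVries & Taylor (2017). Each repository ships its own copy (one variant is
# called with `n_masks` instead of `n_holes`); the following is only a minimal
# reference sketch operating on a CxHxW tensor, not the exact implementation.
import numpy as np
import torch


class Cutout(object):
    """Mask out `n_holes` square patches of side `length` from an image tensor."""

    def __init__(self, n_holes=1, length=16):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        for _ in range(self.n_holes):
            # Random patch center; the patch is clipped at the image border.
            y, x = np.random.randint(h), np.random.randint(w)
            y1, y2 = np.clip([y - self.length // 2, y + self.length // 2], 0, h)
            x1, x2 = np.clip([x - self.length // 2, x + self.length // 2], 0, w)
            mask[y1:y2, x1:x2] = 0.0
        return img * torch.from_numpy(mask).expand_as(img)


# Example call of get_transforms() above (CIFAR-10 statistics assumed, not
# taken from this file); note it returns the test transform first:
# test_tf, train_tf = get_transforms(auto_augment=True, input_sizes=32, m=10,
#                                    mean=(0.4914, 0.4822, 0.4465), n=2,
#                                    std=(0.2470, 0.2435, 0.2616))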
def __init__(self, cfg):
    super(DAGDataset, self).__init__()
    self.template_size = cfg.DAG.TRAIN.TEMPLATE_SIZE
    self.search_size = cfg.DAG.TRAIN.SEARCH_SIZE
    self.size = 25
    self.stride = cfg.DAG.TRAIN.STRIDE

    # Augmentation probabilities / flags
    self.color = cfg.DAG.DATASET.COLOR
    self.flip = cfg.DAG.DATASET.FLIP
    self.rotation = cfg.DAG.DATASET.ROTATION
    self.blur = cfg.DAG.DATASET.BLUR
    self.shift = cfg.DAG.DATASET.SHIFT
    self.scale = cfg.DAG.DATASET.SCALE
    self.gray = cfg.DAG.DATASET.GRAY
    self.label_smooth = cfg.DAG.DATASET.LABELSMOOTH
    self.mixup = cfg.DAG.DATASET.MIXUP
    self.cutout = cfg.DAG.DATASET.CUTOUT
    self.shift_s = cfg.DAG.DATASET.SHIFTs
    self.scale_s = cfg.DAG.DATASET.SCALEs

    self.grids()

    self.neg_num = cfg.DAG.TRAIN.NEG_NUM
    self.pos_num = cfg.DAG.TRAIN.POS_NUM
    self.total_num = cfg.DAG.TRAIN.TOTAL_NUM
    self.neg = cfg.DAG.DATASET.NEG

    # NOTE: each probability is compared against random.random() once, here in
    # __init__, so the set of extra transforms is fixed for the lifetime of the
    # dataset instance rather than re-sampled per item.
    self.transform_extra = transforms.Compose(
        [transforms.ToPILImage()]
        + ([transforms.ColorJitter(0.05, 0.05, 0.05, 0.05)]
           if self.color > random.random() else [])
        + ([transforms.RandomHorizontalFlip()]
           if self.flip > random.random() else [])
        + ([transforms.RandomRotation(degrees=10)]
           if self.rotation > random.random() else [])
        + ([transforms.Grayscale(num_output_channels=3)]
           if self.gray > random.random() else [])
        + ([Cutout(n_holes=1, length=16)]
           if self.cutout > random.random() else []))

    print('train datas: {}'.format(cfg.DAG.TRAIN.WHICH_USE))
    self.train_datas = []
    start = 0
    self.num = 0
    for data_name in cfg.DAG.TRAIN.WHICH_USE:
        dataset = subData(cfg, data_name, start)
        self.train_datas.append(dataset)
        start += dataset.num
        self.num += dataset.num_use

    self._shuffle()
    print(cfg)
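# If per-item re-sampling were desired instead of the fixed-at-init behavior
# noted above, one option (a sketch, not taken from this repository) is
# torchvision's RandomApply, which draws a fresh coin flip on every call:
#
# self.transform_extra = transforms.Compose([
#     transforms.ToPILImage(),
#     transforms.RandomApply([transforms.ColorJitter(0.05, 0.05, 0.05, 0.05)],
#                            p=self.color),
#     transforms.RandomApply([transforms.RandomRotation(degrees=10)],
#                            p=self.rotation),
# ])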
def main2(args):
    best_prec1 = 0.0
    torch.backends.cudnn.deterministic = not args.cudaNoise
    torch.manual_seed(int(time.time()))

    if args.init != "None":
        args.name = "lrnet_%s" % args.init
    if args.tensorboard:
        configure(f"runs/{args.name}")

    dstype = nondigits(args.dataset)
    if dstype == "cifar":
        means = [125.3, 123.0, 113.9]
        stds = [63.0, 62.1, 66.7]
    elif dstype == "imgnet":
        means = [123.3, 118.1, 108.0]
        stds = [54.1, 52.6, 53.2]
    else:
        # CINIC-10 defines its own normalization below; fall back to the CIFAR
        # statistics here so that `normalize` is always defined.
        means = [125.3, 123.0, 113.9]
        stds = [63.0, 62.1, 66.7]
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in means],
        std=[x / 255.0 for x in stds],
    )

    writer = SummaryWriter(log_dir="runs/%s" % args.name, comment=str(args))
    args.classes = onlydigits(args.dataset)

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(
                lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                mode="reflect").squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose(
            [transforms.ToTensor(), normalize])
    if args.cutout:
        transform_train.transforms.append(
            Cutout(n_holes=args.n_holes, length=args.length))
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {"num_workers": 1, "pin_memory": True}
    assert dstype in ["cifar", "cinic", "imgnet"]
    if dstype == "cifar":
        train_loader = torch.utils.data.DataLoader(
            datasets.__dict__[args.dataset.upper()]("../data",
                                                    train=True,
                                                    download=True,
                                                    transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        val_loader = torch.utils.data.DataLoader(
            datasets.__dict__[args.dataset.upper()]("../data",
                                                    train=False,
                                                    transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    elif dstype == "cinic":
        cinic_directory = "%s/cinic10" % args.dir
        cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        cinic_std = [0.24205776, 0.23828046, 0.25874835]
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                cinic_directory + '/train',
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean=cinic_mean, std=cinic_std)
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        print("Using CINIC10 dataset")
        val_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(
                cinic_directory + '/valid',
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean=cinic_mean, std=cinic_std)
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    elif dstype == "imgnet":
        print("Using converted imagenet")
        train_loader = torch.utils.data.DataLoader(
            IMGNET("%s" % args.dir,
                   train=True,
                   transform=transform_train,
                   target_transform=None,
                   classes=args.classes),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        val_loader = torch.utils.data.DataLoader(
            IMGNET("%s" % args.dir,
                   train=False,
                   transform=transform_test,
                   target_transform=None,
                   classes=args.classes),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    else:
        print("Error matching dataset %s" % dstype)

    if args.prune:
        pruner_state = getPruneMask(args)
        if pruner_state is None:
            print("Failed to prune network, aborting")
            return None

    if args.arch.lower() == "constnet":
        model = WideResNet(
            args.layers,
            args.classes,
            args.widen_factor,
            droprate=args.droprate,
            use_bn=args.batchnorm,
            use_fixup=args.fixup,
            varnet=args.varnet,
            noise=args.noise,
            lrelu=args.lrelu,
            sigmaW=args.sigmaW,
            init=args.init,
            dropl1=args.dropl1,
        )
    elif args.arch.lower() == "leakynet":
        model = LRNet(
            args.layers,
            args.classes,
            args.widen_factor,
            droprate=args.droprate,
            use_bn=args.batchnorm,
            use_fixup=args.fixup,
            varnet=args.varnet,
            noise=args.noise,
            lrelu=args.lrelu,
            sigmaW=args.sigmaW,
            init=args.init,
        )
    else:
        print("arch %s is not supported" % args.arch)
        return None

    ## draw(args, model)
    param_num = sum(p.data.nelement() for p in model.parameters())
    print(f"Number of model parameters: {param_num}")

    if torch.cuda.device_count() > 1:
        # args.device is expected in the form "i-j" (e.g. "0-3"): use GPUs i..j.
        start = int(args.device[0])
        end = int(args.device[2]) + 1
        torch.cuda.set_device(start)
        dev_list = []
        for i in range(start, end):
            dev_list.append("cuda:%d" % i)
        model = torch.nn.DataParallel(model, device_ids=dev_list)
    model = model.cuda()

    if args.freeze > 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['scale'], name.split('.')):
                cnt = cnt + 1
                if cnt == args.freeze:
                    break
            if cnt >= args.freeze_start:
                ## if intersection(['conv','conv1'], name.split('.')):
                ##     print("Freezing Block: %s" % name.split('.')[1:3])
                if not intersection(['conv_res', 'fc'], name.split('.')):
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)
    elif args.freeze < 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['scale'], name.split('.')):
                cnt = cnt + 1
            if cnt > args.layers - 3 + args.freeze - 1:
                ## if intersection(['conv','conv1'], name.split('.')):
                ##     print("Freezing Block: %s" % name)
                if not intersection(['conv_res', 'fc'], name.split('.')):
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    if args.res_freeze > 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['conv_res'], name.split('.')):
                cnt = cnt + 1
                if cnt > args.res_freeze_start:
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)
                if cnt >= args.res_freeze:
                    break
    elif args.res_freeze < 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['conv_res'], name.split('.')):
                cnt = cnt + 1
                if cnt > 3 + args.res_freeze:
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    if args.prune:
        if args.prune_epoch >= 100:
            weightsFile = "runs/%s-net/checkpoint.pth.tar" % args.prune
        else:
            weightsFile = "runs/%s-net/model_epoch_%d.pth.tar" % (
                args.prune, args.prune_epoch)
        if os.path.isfile(weightsFile):
            print(f"=> loading checkpoint {weightsFile}")
            checkpoint = torch.load(weightsFile)
            model.load_state_dict(checkpoint["state_dict"])
            print(f"=> loaded checkpoint '{weightsFile}' "
                  f"(epoch {checkpoint['epoch']})")
        else:
            if args.prune_epoch == 0:
                print("=> No source data, restarting network from scratch")
            else:
                print(f"=> no checkpoint found at {weightsFile}, aborting...")
                return None
    else:
        if args.resume:
            tarfile = "runs/%s-net/checkpoint.pth.tar" % args.resume
            if os.path.isfile(tarfile):
                print(f"=> loading checkpoint {args.resume}")
                checkpoint = torch.load(tarfile)
                args.start_epoch = checkpoint["epoch"]
                best_prec1 = checkpoint["best_prec1"]
                model.load_state_dict(checkpoint["state_dict"])
                print(f"=> loaded checkpoint '{tarfile}' "
                      f"(epoch {checkpoint['epoch']})")
            else:
                print(f"=> no checkpoint found at {tarfile}, aborting...")
                return None

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().cuda()
    if args.optimizer.lower() == 'sgd':
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            nesterov=args.nesterov,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer.lower() == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay)

    if args.prune and pruner_state is not None:
        cutoff_retrain = prunhild.cutoff.LocalRatioCutoff(args.cutoff)
        params_retrain = get_params_for_pruning(args, model)
        pruner_retrain = prunhild.pruner.CutoffPruner(params_retrain,
                                                      cutoff_retrain)
        pruner_retrain.load_state_dict(pruner_state)
        pruner_retrain.prune(update_state=False)
        pruned_weights_count = count_pruned_weights(params_retrain,
                                                    args.cutoff)
        params_left = param_num - pruned_weights_count
        print("Pruned %d weights, new model size: %d/%d (%d%%)" %
              (pruned_weights_count, params_left, param_num,
               int(100 * params_left / param_num)))
    else:
        pruner_retrain = None

    if args.eval:
        best_prec1 = validate(args, val_loader, model, criterion, 0, None)
    else:
        if args.varnet:
            save_checkpoint(
                args,
                {
                    "epoch": 0,
                    "state_dict": model.state_dict(),
                    "best_prec1": 0.0,
                },
                True,
            )
        best_prec1 = 0.0
        turns_above_50 = 0
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(args, optimizer, epoch + 1)
            train(args, train_loader, model, criterion, optimizer, epoch,
                  pruner_retrain, writer)
            prec1 = validate(args, val_loader, model, criterion, epoch, writer)
            correlation.measure_correlation(model, epoch, writer=writer)
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            if args.savenet:
                save_checkpoint(
                    args,
                    {
                        "epoch": epoch + 1,
                        "state_dict": model.state_dict(),
                        "best_prec1": best_prec1,
                    },
                    is_best,
                )
            if args.symmetry_break:
                if prec1 > 50.0:
                    turns_above_50 += 1
                    if turns_above_50 > 3:
                        return epoch

    writer.close()
    print("Best accuracy: ", best_prec1)
    return best_prec1
def load_datasets():
    """Create data loaders for the CIFAR-10 dataset.

    Returns:
        Dict containing data loaders.
    """
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize])
    if args.cutout > 0:
        train_transform.transforms.append(Cutout(length=args.cutout))

    # Note: the held-out validation split reuses the training augmentation
    # (random crop + flip); only the test transform is deterministic.
    valid_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize])

    train_dataset = datasets.CIFAR10(root=args.data_path,
                                     train=True,
                                     transform=train_transform,
                                     download=True)
    valid_dataset = datasets.CIFAR10(root=args.data_path,
                                     train=True,
                                     transform=valid_transform,
                                     download=True)
    test_dataset = datasets.CIFAR10(root=args.data_path,
                                    train=False,
                                    transform=test_transform,
                                    download=True)

    # Split the 50k CIFAR-10 training images into 45k train / 5k validation.
    train_indices = list(range(0, 45000))
    valid_indices = list(range(45000, 50000))
    train_subset = Subset(train_dataset, train_indices)
    valid_subset = Subset(valid_dataset, valid_indices)

    data_loaders = {}
    data_loaders['train_subset'] = torch.utils.data.DataLoader(
        dataset=train_subset, batch_size=args.batch_size,
        shuffle=True, pin_memory=True, num_workers=2)
    data_loaders['valid_subset'] = torch.utils.data.DataLoader(
        dataset=valid_subset, batch_size=args.batch_size,
        shuffle=True, pin_memory=True, num_workers=2, drop_last=True)
    data_loaders['train_dataset'] = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=args.batch_size,
        shuffle=True, pin_memory=True, num_workers=2)
    data_loaders['test_dataset'] = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=args.batch_size,
        shuffle=False, pin_memory=True, num_workers=2)
    return data_loaders
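# Minimal usage sketch for load_datasets() (assumes the module-level `args`
# provides data_path, batch_size, and cutout, as the function body does):
#
# loaders = load_datasets()
# images, labels = next(iter(loaders['train_subset']))  # 45k-image split
# # loaders['valid_subset'] holds the remaining 5k training images;
# # loaders['test_dataset'] is the untouched 10k-image test set.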
def main():
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    # data_folder = os.path.join(args.data_folder, 'train')
    data_folder = '/home/C2L/CXR/'

    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.aug == 'NULL':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomRotation(10),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        raise NotImplementedError(
            'augmentation not supported: {}'.format(args.aug))
    train_transform.transforms.append(Cutout(n_holes=3, length=32))

    train_dataset = ImageFolderInstance(data_folder,
                                        transform=train_transform,
                                        two_crop=args.c2l)
    print(len(train_dataset))
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # create model and optimizer
    n_data = len(train_dataset)

    if args.model == 'resnet50':
        model = InsResNet50()
        if args.c2l:
            model_ema = InsResNet50()
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        if args.c2l:
            model_ema = InsResNet50(width=2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        if args.c2l:
            model_ema = InsResNet50(width=4)
    elif args.model == 'resnet18':
        model = InsResNet18(width=1)
        if args.c2l:
            model_ema = InsResNet18(width=1)
    elif args.model == 'densenet121':
        model = DenseNet121(isTrained=False)
        if args.c2l:
            model_ema = DenseNet121(isTrained=False)
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # copy weights from `model' to `model_ema'
    if args.c2l:
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if args.c2l:
        contrast = MemoryC2L(128, n_data, args.nce_k, args.nce_t,
                             args.softmax).cuda(args.gpu)
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t,
                                args.nce_m, args.softmax).cuda(args.gpu)

    criterion = NCESoftmaxLoss() if args.softmax else NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    if args.c2l:
        model_ema = model_ema.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level)
        if args.c2l:
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0, momentum=0, weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(
                model_ema, optimizer_ema, opt_level=args.opt_level)

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            if args.c2l:
                model_ema.load_state_dict(checkpoint['model_ema'])
            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])
            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    # logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):
        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        if args.c2l:
            loss, prob = train_C2L(epoch, train_loader, model, model_ema,
                                   contrast, criterion, optimizer, args)
        else:
            loss, prob = train_ins(epoch, train_loader, model, contrast,
                                   criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        # logger.log_value('ins_loss', loss, epoch)
        # logger.log_value('ins_prob', prob, epoch)
        # logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        # save the current model every epoch, plus a numbered checkpoint
        # every `save_freq` epochs
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if args.c2l:
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()
        save_file = os.path.join(args.model_folder, 'current.pth')
        torch.save(state, save_file)
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
def init(batch_size, state, mean, std, input_sizes, base, num_workers,
         train_set, val_set, rand_augment=True, n=1, m=1, dataset='cifar10'):
    # # Original transforms
    # train_transforms = Compose([
    #     Pad(padding=4, padding_mode='reflect'),
    #     RandomHorizontalFlip(p=0.5),
    #     RandomCrop(size=input_sizes),
    #     ToTensor(),
    #     Normalize(mean=mean, std=std)
    # ])
    if rand_augment:
        # RandAugment + Cutout
        train_transforms = Compose([
            RandomCrop(size=input_sizes, padding=4, fill=128),
            RandomHorizontalFlip(p=0.5),
            RandomRandAugment(n=n, m_max=m),  # This version includes cutout
            ToTensor(),
            Normalize(mean=mean, std=std)
        ])
    else:
        # AutoAugment + Cutout
        train_transforms = Compose([
            RandomCrop(size=input_sizes, padding=4, fill=128),
            RandomHorizontalFlip(p=0.5),
            CIFAR10Policy(),
            ToTensor(),
            Normalize(mean=mean, std=std),
            Cutout(n_holes=1, length=16)
        ])
    test_transforms = Compose([ToTensor(), Normalize(mean=mean, std=std)])

    # Data sets
    if dataset == 'cifar10':
        if state == 1:
            train_set = CIFAR10(root=base, set_name=train_set,
                                transform=train_transforms, label=True)
        test_set = CIFAR10(root=base, set_name=val_set,
                           transform=test_transforms, label=True)
    else:
        raise NotImplementedError

    # Data loaders
    if state == 1:
        train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                   batch_size=batch_size,
                                                   num_workers=num_workers,
                                                   shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=batch_size,
                                              num_workers=num_workers * 2,
                                              shuffle=False)

    if state == 1:
        return train_loader, test_loader
    else:
        return test_loader
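# Example call of init() above (illustrative values only; `base`, `train_set`,
# and `val_set` must match this repository's own CIFAR10 split names):
#
# train_loader, test_loader = init(
#     batch_size=128, state=1,
#     mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616),
#     input_sizes=32, base='./data', num_workers=4,
#     train_set='train', val_set='val', rand_augment=True, n=2, m=10)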
def main():
    global args, best_prec1
    args = parser.parse_args()

    if args.tensorboard:
        configure(f"runs/{args.name}")

    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
    )

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(
                lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                mode="reflect").squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose(
            [transforms.ToTensor(), normalize])

    if args.cutout:
        transform_train.transforms.append(
            Cutout(n_holes=args.n_holes, length=args.length))

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {"num_workers": 1, "pin_memory": True}
    assert args.dataset in ("cifar10", "cifar100")
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()](
            "../data", train=True, download=True, transform=transform_train),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()](
            "../data", train=False, transform=transform_test),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )

    model = WideResNet(
        args.layers,
        10 if args.dataset == "cifar10" else 100,
        args.widen_factor,
        droprate=args.droprate,
        use_bn=args.batchnorm,
        use_fixup=args.fixup,
    )

    param_num = sum(p.data.nelement() for p in model.parameters())
    print(f"Number of model parameters: {param_num}")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model = model.cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(f"=> loading checkpoint {args.resume}")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            print(f"=> loaded checkpoint '{args.resume}' "
                  f"(epoch {checkpoint['epoch']})")
        else:
            print(f"=> no checkpoint found at {args.resume}")

    cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        nesterov=args.nesterov,
        weight_decay=args.weight_decay,
    )

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)
        train(train_loader, model, criterion, optimizer, epoch)
        prec1 = validate(val_loader, model, criterion, epoch)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_prec1": best_prec1,
            },
            is_best,
        )
    print("Best accuracy: ", best_prec1)
    # NOTE: the opening of this snippet was truncated in the source; the
    # 'train' entry is reconstructed here assuming a standard ImageNet-style
    # RandomResizedCrop(224) pipeline matching the 'val' entry below.
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean_vec, std_vec)
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean_vec, std_vec)
        ]),
    }

    if args.cutout:
        data_transforms['train'].transforms.append(
            Cutout(n_holes=1, length=args.cutout_size))

    train_dir = os.path.join(DATA_DIR, 'train')
    val_dir = os.path.join(DATA_DIR, 'val')
    train_dataset = datasets.ImageFolder(train_dir, data_transforms['train'])
    val_dataset = datasets.ImageFolder(val_dir, data_transforms['val'])
    print('train_dataset.size', len(train_dataset.samples))
    print('val_dataset.size', len(val_dataset.samples))

    image_datasets = {'train': train_dataset, 'val': val_dataset}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    class_names = image_datasets['train'].classes
    print('class_names:', class_names)

    # The tail of this snippet was also truncated; the remaining DataLoader
    # arguments below are assumed for illustration.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)
def main(args):
    global best_prec1

    # CIFAR-10 training & test transforms
    print('. . . . . . . . . . . . . . . . PREPROCESSING DATA . . . . . . . . . . . . . . . .')
    TRAIN_transform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    if args.cutout:
        # Note: this repository's Cutout takes `n_masks` rather than `n_holes`.
        TRAIN_transform.transforms.append(
            Cutout(n_masks=args.n_masks, length=args.length))
    VAL_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    # CIFAR-10 dataset
    train_dataset = torchvision.datasets.CIFAR10(root='../data/',
                                                 train=True,
                                                 transform=TRAIN_transform,
                                                 download=True)
    val_dataset = torchvision.datasets.CIFAR10(root='../data/',
                                               train=False,
                                               transform=VAL_transform,
                                               download=True)

    # Data loaders
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               pin_memory=True,
                                               drop_last=True,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             pin_memory=True,
                                             batch_size=args.batch_size,
                                             shuffle=False)

    # Device config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.normalize == 'groupnorm':
        model = SEresnet_gn()
    elif args.normalize == 'groupnorm+ws':
        model = SEresnet_gn_ws()
    else:
        model = SEresnet()
    model = model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          momentum=args.momentum)
    lr_schedule = lr_scheduler.MultiStepLR(optimizer,
                                           milestones=[250, 375],
                                           gamma=0.1)

    if args.evaluate:
        model.load_state_dict(torch.load('./save_model/model.th'))
        model.to(device)
        validation(args, val_loader, model, criterion)

    for epoch_ in range(0, args.Epoch):
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train_one_epoch(args, train_loader, model, criterion, optimizer,
                        epoch_)
        lr_schedule.step()

        prec1 = validation(args, val_loader, model, criterion)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch_ > 0 and epoch_ % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch_ + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.pt'))
        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(args.save_dir, 'model.pt'))

    print('THE BEST MODEL prec@1 : {best_prec1:.3f} saved.'.format(
        best_prec1=best_prec1))