# Reproducibility: seed all three RNGs used downstream (torch, numpy, stdlib random).
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# ---------------------------------------------------------------------------------------------------------------
# Loading the datasets
# ---------------------------------------------------------------------------------------------------------------
# We load the training dset and the repulsive dset.
# Each branch pairs ToTensor with a Normalize whose mean/std constants are
# dataset-specific (presumably precomputed over each training set — TODO confirm).
if args.dataset.lower() == 'mnist':
    # Load transforms
    tfms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize((0.1307,), (0.3081,))])
    if final:
        # Final evaluation: use the official MNIST test split.
        dset = torchvision.datasets.MNIST('../../../datasets/MNIST',
                                          train=False,
                                          download=True,
                                          transform=tfms)
    else:
        # Otherwise carve a validation fold (the second return value) out of
        # the training split; the other pieces of the split are discarded.
        full_dset = torchvision.datasets.MNIST('../../../datasets/MNIST',
                                               train=True,
                                               download=True,
                                               transform=tfms)
        _, dset, _, _ = dataset.train_valid_split(full_dset,
                                                  split_fold=10,
                                                  random_seed=dataset_seed)
elif args.dataset.lower() == 'notmnist':
    tfms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize((0.4240,), (0.4583,))])
    # Create the dset (project-local notMNIST wrapper, test split only)
    dset = dataset.notMNIST('../../../datasets/notMNIST', train=False, transform=tfms)
elif args.dataset.lower() == 'kmnist':
    tfms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize((0.1832,), (0.3405,))])
    # Create the dset (project-local Kuzushiji-MNIST wrapper; note it takes
    # `tfms=` and a string split selector rather than torchvision's interface)
    dset = dataset.KujuMNIST_DS('../../../datasets/Kuzushiji-MNIST',
                                train_or_test='test',
                                download=True,
                                tfms=tfms)
elif args.dataset.lower() == 'emnist':
    tfms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize((0.1733,), (0.3317,))])
    # Create the dset ('letters' split of EMNIST, test portion)
    dset = torchvision.datasets.EMNIST('../../../datasets/emnist',
                                       split='letters',
                                       download=True,
                                       train=False,
                                       transform=tfms)
else:
    # NOTE(review): unknown dataset only prints a warning — `dset` stays
    # undefined and later code will raise NameError. Verify this is intended.
    print('Bad dataset: can\'t load dataset {}'.format(args.dataset))
def main(args):
    """Train or evaluate the retinopathy EfficientNet classifier.

    Dispatches on ``args.train``: 0 trains with a train/validation split,
    1 runs prediction on the test set. All paths, batch size, epoch count
    and verbosity come from ``args``.
    """
    train_dir = args.train_dir
    train_csv = args.train_csv
    test_dir = args.test_dir
    test_csv = args.test_csv
    ratio = args.train_valid_ratio
    batch_size = args.batch_size
    epochs = args.epochs
    train_flag = args.train
    pretrain_weight = args.pretrain_weight
    verbose = args.verbose

    if train_flag == 0:
        if verbose == 2:
            print("Reading Training Data...")
        train_csv = pd.read_csv(train_csv)
        train_csv, valid_csv = train_valid_split(train_csv, ratio)
        train = RetinopathyDataset(train_csv, train_dir)
        valid = RetinopathyDataset(valid_csv, train_dir)
        if verbose == 2:
            print("Creating DataLoader...")
        train_dataloader = DataLoader(train,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=4)
        # Validation needs no gradients, so a 4x larger batch fits in memory.
        valid_dataloader = DataLoader(valid,
                                      batch_size=batch_size * 4,
                                      shuffle=False,
                                      num_workers=4)
        if verbose == 2:
            print("Creating EfficientNet Model...")
        # NOTE(review): the training branch hard-codes the APTOS-2018 weights
        # and ignores args.pretrain_weight — confirm this is intentional.
        model = EfficientNetFinetune(
            level="efficientnet-b5",
            finetune=False,
            pretrain_weight="./weights/pretrained/aptos2018.pth")
        trainer = Trainer(model,
                          train_dataloader,
                          valid_dataloader,
                          epochs,
                          early_stop="QK",
                          verbose=verbose)
        if verbose == 2:
            print("Start Training...")  # fixed typo: was "Strat Training..."
        trainer.train()

    if train_flag == 1:
        if verbose == 2:
            print("Start Predicting...")  # fixed typo: was "Strat Predicting..."
        test_csv = pd.read_csv(test_csv)
        test = RetinopathyDataset(test_csv, test_dir, test=True)
        test_dataloader = DataLoader(test,
                                     batch_size=batch_size * 4,
                                     shuffle=False,
                                     num_workers=4)
        model = EfficientNetFinetune(level="efficientnet-b5",
                                     finetune=False,
                                     test=True,
                                     pretrain_weight=pretrain_weight)
        tester(model, test_dataloader, verbose)
def train(args, experiment=None, device=None):
    """Train a classifier, optionally with a repulsive / prior-based regularizer.

    Builds the network (MNIST only for now), splits the training data,
    optionally constructs a repulsive sampler from an OOD dataset, runs the
    optimization loop with periodic validation, and checkpoints the model
    (and the prior, when ``args.beta > 0``) under ``args.save_folder``.

    Parameters
    ----------
    args : argparse.Namespace
        Full experiment configuration (seeds, dataset, lambdas, paths, ...).
    experiment : optional
        Comet.ml experiment handle; all logging calls are commented out below.
    device : torch.device, optional
        Device the model and batches are moved to.
    """
    # ---------------------------------------
    # Definition of the hyperparameters
    # ---------------------------------------
    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)
    # Loading dataset parameters
    if args.train.lower() == 'mnist':
        net = models.NNMNIST(28 * 28, 10)
        # The "prior" is a second network used by the regularizer:
        # - beta > 0: a fresh random net, frozen (eval mode);
        # - lambda_anchoring > 0: a frozen copy of the initial weights of `net`
        #   (anchored ensembling);
        # - otherwise no prior is needed.
        if args.beta > 0.0:
            prior = models.NNMNIST(28 * 28, 10)
            prior.eval()
        elif args.lambda_anchoring > 0.0:
            prior = deepcopy(net)
            prior.eval()
        else:
            prior = None
        net.to(device)
        if prior is not None:
            prior.to(device)
        # Load transforms
        tfms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        full_dset = torchvision.datasets.MNIST('../../../datasets/MNIST',
                                               train=True,
                                               download=True,
                                               transform=tfms)
        # Flatten images to vectors for the fully-connected model.
        prepr = lambda x: x.view(-1, 28 * 28)
    else:
        raise ValueError('Bad training dataset selected: {}'.format(
            args.train.lower()))
    # Create training and validation split
    train_dset, val_dset, _, _ = dataset.train_valid_split(
        full_dset, split_fold=10, random_seed=args.dataset_seed)
    if args.bootstrapping:
        # Bootstrap resample: draw indices with replacement from the split's
        # index mapping (assumes the split object exposes .mapping/.length —
        # TODO confirm against dataset.train_valid_split).
        new_mapping = np.random.choice(np.asarray(train_dset.mapping),
                                       size=train_dset.length)
        train_dset.mapping = new_mapping
    train_loader, val_loader = torch.utils.data.DataLoader(
        train_dset, batch_size=args.batch_size_train,
        shuffle=True), torch.utils.data.DataLoader(
            val_dset, batch_size=args.batch_size_val, shuffle=True)
    # We create a configuration file with all the parameters
    model_name = 'repulsive_train:{}_repulsive:{}_lambda:{}_bandwidth:{}'.format(
        args.train.lower(), args.repulsive, args.lambda_repulsive,
        args.bandwidth_repulsive)
    if args.id is not None:
        model_name = model_name + '_{}'.format(args.id)
    savepath = Path(args.save_folder)
    try:
        if not Path.exists(savepath):
            os.makedirs(savepath)
        if not Path.exists(
                savepath / 'config.json'):  # Only create json if it does not exist
            with open(savepath / 'config.json', 'w') as fd:
                json.dump(vars(args), fd)
    except FileExistsError:
        # Best-effort: a concurrent run may have created the folder/file first.
        print('File already exists')
        pass
    # If the experiment is name we save it in results directly.
    # experiment.log_parameters(vars(args))
    # Re-seed (unconditionally this time) so training is reproducible even
    # when args.seed is None was handled above — NOTE(review): if args.seed is
    # None these calls receive None; verify that is intended.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(1)
    VAL_FREQ = 1  # validate every VAL_FREQ epochs
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    # Load the reference net
    if args.repulsive is not None:
        if args.repulsive.lower() == 'fashionmnist':
            # For the repulsive loader we don't need to split into train and
            # validation, we can use the full set
            tfms = torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.2859, ), (0.3530, ))
            ])
            dset_repulsive = torchvision.datasets.FashionMNIST(
                '../../../datasets/FashionMNIST',
                train=True,
                download=False,
                transform=tfms)
            # Load the repulsive model (pretrained reference the regularizer
            # pushes away from)
            raw_model = models.NNMNIST(28 * 28, 10)
            reference_net = tools.load_model(Path(args.reference_net), raw_model)
            reference_net.eval()
        else:
            raise ValueError('Bad repulsive dataset selected: {}'.format(
                args.repulsive.lower()))
        # Create repulsive sampler
        repulsive_loader = torch.utils.data.DataLoader(
            dset_repulsive, batch_size=args.batch_size_repulsive, shuffle=True)
        repulsive_sampler = sampler.repulsiveSampler(
            args.repulsive.upper(),
            dataloader=repulsive_loader,
            batch_size=args.batch_size_repulsive)
    print('Finished loading the datasets.')
    # Partial functions: bake the repulsive hyperparameters into the optimizer
    # step so the training loop can call it uniformly.
    if args.repulsive is not None:
        _optimize = partial(tools.optimize,
                            bandwidth_repulsive=args.bandwidth_repulsive,
                            lambda_repulsive=args.lambda_repulsive)
    else:
        _optimize = tools.optimize
    # --------------------------------------------------------------------------------
    # Training
    # --------------------------------------------------------------------------------
    step = 0
    if args.lambda_anchoring > 0.0:
        fac_norm = compute_norm_fac(net)
    for epoch in tqdm(range(args.n_epochs), desc='epochs'):
        # Training phase
        net.train()
        _tqdm = tqdm(train_loader, desc='batch')
        # experiment.log_current_epoch(epoch)
        for j, batch_raw in enumerate(_tqdm):
            if args.repulsive is not None:
                br = repulsive_sampler.sample_batch()
                batch_repulsive = br.to(device)
            # optimization part
            # prepare the batch, we get images not vectors !
            x_raw, y = batch_raw
            if args.repulsive is not None:
                batch_repulsive = prepr(batch_repulsive)
            x_raw, y = prepr(x_raw), y.view(-1)
            batch = (x_raw.to(device), y.to(device))
            # Select the extra arguments for the chosen regularization mode
            # (repulsive > beta-prior > anchoring > plain, in that priority).
            if args.repulsive is not None:
                kwargs = {
                    'reference_net': reference_net,
                    'batch_repulsive': batch_repulsive
                }
            elif args.beta > 0.0:
                kwargs = {'beta': args.beta, 'prior': prior}
            elif args.lambda_anchoring > 0.0:
                kwargs = {
                    'lambda_anchoring': args.lambda_anchoring,
                    'prior': prior,
                    'fac_norm': fac_norm
                }
            else:
                kwargs = {}
            info_training = _optimize(
                net,
                optimizer,
                batch,
                add_repulsive_constraint=args.repulsive is not None,
                **kwargs)
            if args.verbose:
                _tqdm.set_description('Epoch {}/{}, loss: {:.4f}'.format(
                    epoch + 1, args.n_epochs, info_training['loss']))
            # # Log to Comet.ml
            # for k, v in info_training.items():
            #     experiment.log_metric(k, float(v), step=step)
            step += 1
        if not Path.exists(savepath / 'models'):
            os.makedirs(savepath / 'models')
        # Periodic checkpoint; refuses to overwrite an existing file.
        if (epoch > 0 and epoch % args.save_freq == 0):
            model_path = savepath / 'models' / '{}_{}epochs.pt'.format(
                model_name, epoch + 1)
            if not Path.exists(model_path):
                torch.save(net.state_dict(), model_path)
            else:
                raise ValueError(
                    'Error trying to save file at location {}: File already exists'
                    .format(model_path))
        if epoch % VAL_FREQ == 0:
            # Evaluate on validation set
            xent = nn.CrossEntropyLoss()
            net.eval()
            total_val_loss, total_val_acc = 0.0, 0.0
            n_val = len(val_loader.dataset)
            for j, batch_raw in enumerate(val_loader):
                x_raw, y = batch_raw
                len_batch = x_raw.size(0)
                x_raw, y = prepr(x_raw), y.view(-1)
                x, y = x_raw.to(device), y.to(device)
                y_logit = net(x)
                # logging: batch-size-weighted running averages over the
                # validation set.
                total_val_loss += (len_batch / n_val) * xent(
                    y_logit, y.view(-1)).item()
                total_val_acc += (y_logit.argmax(1) == y).float().sum().item() / n_val
            # Compute statistics
            print('Epoch {}/{}, val acc: {:.3f}, val loss: {:.3f}'.format(
                epoch + 1, args.n_epochs, total_val_acc, total_val_loss))
            # experiment.log_metric("val_accuracy", total_val_acc)
            # experiment.log_metric("val_loss", total_val_loss)
    # POST-PROCESSING
    # Save the model (final checkpoint; `epoch` holds the last loop value)
    try:
        dirname = 'models'
        if not Path.exists(savepath / dirname):
            os.makedirs(savepath / dirname)
        if args.beta > 0.0:
            dirname_priors = 'priors'
            if not Path.exists(savepath / dirname_priors):
                os.makedirs(savepath / dirname_priors)
        model_path = savepath / dirname / '{}_{}epochs.pt'.format(
            model_name, epoch + 1)
        if not Path.exists(model_path):
            torch.save(net.state_dict(), model_path)
        if args.beta > 0.0:
            prior_path = savepath / dirname_priors / '{}_{}epochs.pt'.format(
                model_name, epoch + 1)
            if not Path.exists(prior_path):
                torch.save(prior.state_dict(), prior_path)
    except FileExistsError:
        # NOTE(review): message has a {} placeholder but no .format() call —
        # it prints literally; likely a latent bug to fix separately.
        print('Error trying to save file at location {}: File already exists')
# Echo the run configuration (`state` is defined outside this excerpt —
# presumably a dict of the experiment settings; TODO confirm).
print(state)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
torch.set_num_threads(1)
# In-distribution data: MNIST train/test with its standard mean/std.
tfms = trn.Compose([trn.ToTensor(), trn.Normalize((.1307, ), (.3081, ))])
full_train_data_in = dset.MNIST('../../../datasets/MNIST',
                                train=True,
                                transform=tfms)
test_data = dset.MNIST('../../../datasets/MNIST',
                       train=False,
                       transform=tfms)
num_classes = 10
# Splitting the data into train and validation
train_data_in, val_data_in, _, _ = dataset.train_valid_split(
    full_train_data_in, split_fold=10, random_seed=args.dataset_seed)
calib_indicator = ''
if args.calibration:
    # Reserve 10% of the training data for calibration; the indicator string
    # tags output filenames produced by calibrated runs.
    train_data_in, val_data = validation_split(train_data_in, val_share=0.1)
    calib_indicator = 'calib_'
# Original outlier-exposure source, kept for reference:
#tiny_images = TinyImages(transform=trn.Compose(
#    [trn.ToTensor(), trn.ToPILImage(), trn.Resize(28),
#     trn.Lambda(lambda x: x.convert('L', (0.2989, 0.5870, 0.1140, 0))),
#     trn.RandomHorizontalFlip(), trn.ToTensor()]))
#
# Instead of load tiny images, we load fashionmnist as the out-of-distribution set.
tfms_fashionmnist = trn.Compose(
    [trn.ToTensor(), trn.Normalize((.2859, ), (.3530, ))])