torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# ---------------------------------------------------------------------------------------------------------------
# Loading the datasets
# ---------------------------------------------------------------------------------------------------------------
# We load the training dset and the repulsive dset
if args.dataset.lower() == 'mnist':
    # Load transforms
    tfms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))])
    if final:
        dset = torchvision.datasets.MNIST('../../../datasets/MNIST', train=False, download=True, transform=tfms)
    else:
        full_dset = torchvision.datasets.MNIST('../../../datasets/MNIST', train=True, download=True, transform=tfms)
        _, dset, _, _ = dataset.train_valid_split(full_dset, split_fold=10, random_seed=dataset_seed)
elif args.dataset.lower() == 'notmnist':
    tfms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.4240,), (0.4583,))])
    # Create the dset
    dset = dataset.notMNIST('../../../datasets/notMNIST', train=False, transform=tfms)
elif args.dataset.lower() == 'kmnist':
    tfms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1832,), (0.3405,))])
    # Create the dset
    dset = dataset.KujuMNIST_DS('../../../datasets/Kuzushiji-MNIST', train_or_test='test', download=True, tfms=tfms)
elif args.dataset.lower() == 'emnist':
    tfms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1733,), (0.3317,))])
    # Create the dset
    dset = torchvision.datasets.EMNIST('../../../datasets/emnist', split='letters', download=True, train=False, transform=tfms)
else:
    raise ValueError("Bad dataset: can't load dataset {}".format(args.dataset))
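
# A sketch (commented out) of how the loaded dset would typically be consumed
# downstream; the loader name and batch size are assumptions, not part of the
# original snippet:
#
# eval_loader = torch.utils.data.DataLoader(dset, batch_size=256, shuffle=False)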
# ---------------------------------------------------------------------------------------------------------------
# Example #2
# ---------------------------------------------------------------------------------------------------------------

def main(args):

    train_dir = args.train_dir
    train_csv = args.train_csv
    test_dir = args.test_dir
    test_csv = args.test_csv

    ratio = args.train_valid_ratio
    batch_size = args.batch_size
    epochs = args.epochs

    train_flag = args.train
    pretrain_weight = args.pretrain_weight
    verbose = args.verbose

    if train_flag == 0:
        if verbose == 2:
            print("Reading Training Data...")

        train_csv = pd.read_csv(train_csv)
        train_csv, valid_csv = train_valid_split(train_csv, ratio)
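
        # train_valid_split here is a project helper (definition not shown);
        # a minimal sketch (commented out) of the assumed behaviour on a
        # DataFrame, for illustration only:
        #
        #   def train_valid_split(df, ratio):
        #       shuffled = df.sample(frac=1.0).reset_index(drop=True)
        #       cut = int(len(shuffled) * ratio)
        #       return shuffled.iloc[:cut], shuffled.iloc[cut:]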

        train = RetinopathyDataset(train_csv, train_dir)
        valid = RetinopathyDataset(valid_csv, train_dir)

        if verbose == 2:
            print("Creating DataLoader...")

        train_dataloader = DataLoader(train,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=4)
        valid_dataloader = DataLoader(valid,
                                      batch_size=batch_size * 4,
                                      shuffle=False,
                                      num_workers=4)

        if verbose == 2:
            print("Creating EfficientNet Model...")

        model = EfficientNetFinetune(
            level="efficientnet-b5",
            finetune=False,
            pretrain_weight="./weights/pretrained/aptos2018.pth")

        trainer = Trainer(model,
                          train_dataloader,
                          valid_dataloader,
                          epochs,
                          early_stop="QK",
                          verbose=verbose)

        if verbose == 2:
            print("Start Training...")
        trainer.train()

    if train_flag == 1:
        if verbose == 2:
            print("Start Predicting...")

        test_csv = pd.read_csv(test_csv)
        test = RetinopathyDataset(test_csv, test_dir, test=True)
        test_dataloader = DataLoader(test,
                                     batch_size=batch_size * 4,
                                     shuffle=False,
                                     num_workers=4)
        model = EfficientNetFinetune(level="efficientnet-b5",
                                     finetune=False,
                                     test=True,
                                     pretrain_weight=pretrain_weight)
        tester(model, test_dataloader, verbose)
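
# A sketch (commented out) of how main() could be driven from the command
# line; every flag below is an assumption inferred from the attribute
# accesses inside main():
#
# import argparse
#
# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--train_dir')
#     parser.add_argument('--train_csv')
#     parser.add_argument('--test_dir')
#     parser.add_argument('--test_csv')
#     parser.add_argument('--train_valid_ratio', type=float, default=0.8)
#     parser.add_argument('--batch_size', type=int, default=16)
#     parser.add_argument('--epochs', type=int, default=30)
#     parser.add_argument('--train', type=int, default=0)
#     parser.add_argument('--pretrain_weight', default=None)
#     parser.add_argument('--verbose', type=int, default=2)
#     main(parser.parse_args())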
# ---------------------------------------------------------------------------------------------------------------
# Example #3
# ---------------------------------------------------------------------------------------------------------------

def train(args, experiment=None, device=None):
    # ---------------------------------------
    # Definition of the hyperparameters
    # ---------------------------------------

    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    # Loading dataset parameters
    if args.train.lower() == 'mnist':
        net = models.NNMNIST(28 * 28, 10)
        if args.beta > 0.0:
            prior = models.NNMNIST(28 * 28, 10)
            prior.eval()
        elif args.lambda_anchoring > 0.0:
            prior = deepcopy(net)
            prior.eval()
        else:
            prior = None
        net.to(device)
        if prior is not None:
            prior.to(device)
        # Load transforms
        tfms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        full_dset = torchvision.datasets.MNIST('../../../datasets/MNIST',
                                               train=True,
                                               download=True,
                                               transform=tfms)
        prepr = lambda x: x.view(-1, 28 * 28)
    else:
        raise ValueError('Bad training dataset selected: {}'.format(
            args.train.lower()))

    # Create training and validation split
    train_dset, val_dset, _, _ = dataset.train_valid_split(
        full_dset, split_fold=10, random_seed=args.dataset_seed)
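    # train_valid_split is a project helper (not shown here); from the call
    # above and the .mapping / .length attributes used below, it appears to
    # wrap the dataset with disjoint index mappings, holding out 1/split_fold
    # of the examples for validation.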
    if args.bootstrapping:
        new_mapping = np.random.choice(np.asarray(train_dset.mapping),
                                       size=train_dset.length)
        train_dset.mapping = new_mapping
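        # This draws len(train_dset) indices with replacement, turning the
        # training set into a bootstrap resample: some examples now appear
        # several times while others are left out (useful for ensembles).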
    train_loader = torch.utils.data.DataLoader(train_dset,
                                               batch_size=args.batch_size_train,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dset,
                                             batch_size=args.batch_size_val,
                                             shuffle=True)

    # We create a configuration file with all the parameters
    model_name = 'repulsive_train:{}_repulsive:{}_lambda:{}_bandwidth:{}'.format(
        args.train.lower(), args.repulsive, args.lambda_repulsive,
        args.bandwidth_repulsive)
    if args.id is not None:
        model_name = model_name + '_{}'.format(args.id)

    savepath = Path(args.save_folder)
    try:
        if not Path.exists(savepath):
            os.makedirs(savepath)

        if not Path.exists(
                savepath /
                'config.json'):  # Only create json if it does not exist
            with open(savepath / 'config.json', 'w') as fd:
                json.dump(vars(args), fd)
    except FileExistsError:
        print('File already exists')

    # If the experiment is named, we save it in results directly.
    # experiment.log_parameters(vars(args))

    torch.set_num_threads(1)

    VAL_FREQ = 1

    optimizer = optim.Adam(net.parameters(), lr=args.lr)

    # Load the reference net
    if args.repulsive is not None:
        if args.repulsive.lower() == 'fashionmnist':
            # For the repulsive loader we don't need to split into train and validation, we can use the full set
            tfms = torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.2859, ), (0.3530, ))
            ])
            dset_repulsive = torchvision.datasets.FashionMNIST(
                '../../../datasets/FashionMNIST',
                train=True,
                download=False,
                transform=tfms)

            # Load the repulsive model
            raw_model = models.NNMNIST(28 * 28, 10)
            reference_net = tools.load_model(Path(args.reference_net),
                                             raw_model)
            reference_net.eval()

        else:
            raise ValueError('Bad repulsive dataset selected: {}'.format(
                args.repulsive.lower()))

        # Create repulsive sampler
        repulsive_loader = torch.utils.data.DataLoader(
            dset_repulsive, batch_size=args.batch_size_repulsive, shuffle=True)
        repulsive_sampler = sampler.repulsiveSampler(
            args.repulsive.upper(),
            dataloader=repulsive_loader,
            batch_size=args.batch_size_repulsive)

    print('Finished loading the datasets.')

    # Partial functions
    if args.repulsive is not None:
        _optimize = partial(tools.optimize,
                            bandwidth_repulsive=args.bandwidth_repulsive,
                            lambda_repulsive=args.lambda_repulsive)
    else:
        _optimize = tools.optimize
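
    # In both branches _optimize now has the same call signature: partial()
    # pre-binds the repulsive hyperparameters, so the training loop below can
    # invoke it uniformly with (net, optimizer, batch, **kwargs).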

    # --------------------------------------------------------------------------------
    # Training
    # --------------------------------------------------------------------------------
    step = 0

    if args.lambda_anchoring > 0.0:
        fac_norm = compute_norm_fac(net)

    for epoch in tqdm(range(args.n_epochs), desc='epochs'):
        # Training phase
        net.train()
        _tqdm = tqdm(train_loader, desc='batch')
        # experiment.log_current_epoch(epoch)
        for j, batch_raw in enumerate(_tqdm):

            if args.repulsive is not None:
                br = repulsive_sampler.sample_batch()
                batch_repulsive = br.to(device)

            # Optimization step: prepare the batch (we get images, not flattened vectors)
            x_raw, y = batch_raw
            if args.repulsive is not None:
                batch_repulsive = prepr(batch_repulsive)
            x_raw, y = prepr(x_raw), y.view(-1)
            batch = (x_raw.to(device), y.to(device))

            if args.repulsive is not None:
                kwargs = {
                    'reference_net': reference_net,
                    'batch_repulsive': batch_repulsive
                }
            elif args.beta > 0.0:
                kwargs = {'beta': args.beta, 'prior': prior}
            elif args.lambda_anchoring > 0.0:
                kwargs = {
                    'lambda_anchoring': args.lambda_anchoring,
                    'prior': prior,
                    'fac_norm': fac_norm
                }
            else:
                kwargs = {}
            info_training = _optimize(
                net,
                optimizer,
                batch,
                add_repulsive_constraint=(args.repulsive is not None),
                **kwargs)
            if args.verbose:
                _tqdm.set_description('Epoch {}/{}, loss: {:.4f}'.format(
                    epoch + 1, args.n_epochs, info_training['loss']))

            # # Log to Comet.ml
            # for k, v in info_training.items():
            #     experiment.log_metric(k, float(v), step=step)
            step += 1

        if not Path.exists(savepath / 'models'):
            os.makedirs(savepath / 'models')

        if epoch > 0 and epoch % args.save_freq == 0:
            model_path = savepath / 'models' / '{}_{}epochs.pt'.format(
                model_name, epoch + 1)
            if not Path.exists(model_path):
                torch.save(net.state_dict(), model_path)
            else:
                raise ValueError(
                    'Error trying to save file at location {}: File already exists'
                    .format(model_path))

        if epoch % VAL_FREQ == 0:

            # Evaluate on validation set
            xent = nn.CrossEntropyLoss()
            net.eval()
            total_val_loss, total_val_acc = 0.0, 0.0
            n_val = len(val_loader.dataset)

            for j, batch_raw in enumerate(val_loader):
                x_raw, y = batch_raw
                len_batch = x_raw.size(0)
                x_raw, y = prepr(x_raw), y.view(-1)
                x, y = x_raw.to(device), y.to(device)
                y_logit = net(x)

                # logging
                total_val_loss += (len_batch / n_val) * xent(
                    y_logit, y.view(-1)).item()
                total_val_acc += (y_logit.argmax(1)
                                  == y).float().sum().item() / n_val
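                # Each batch contributes its size-weighted share, so after the
                # loop these totals equal the dataset-level mean loss and accuracy.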

            # Compute statistics
            print('Epoch {}/{}, val acc: {:.3f}, val loss: {:.3f}'.format(
                epoch + 1, args.n_epochs, total_val_acc, total_val_loss))
            # experiment.log_metric("val_accuracy", total_val_acc)
            # experiment.log_metric("val_loss", total_val_loss)

    # POST-PROCESSING
    # Save the model
    try:
        dirname = 'models'
        if not Path.exists(savepath / dirname):
            os.makedirs(savepath / dirname)
        if args.beta > 0.0:
            dirname_priors = 'priors'
            if not Path.exists(savepath / dirname_priors):
                os.makedirs(savepath / dirname_priors)

        model_path = savepath / dirname / '{}_{}epochs.pt'.format(
            model_name, epoch + 1)
        if not Path.exists(model_path):
            torch.save(net.state_dict(), model_path)
        if args.beta > 0.0:
            prior_path = savepath / dirname_priors / '{}_{}epochs.pt'.format(
                model_name, epoch + 1)
            if not Path.exists(prior_path):
                torch.save(prior.state_dict(), prior_path)
    except FileExistsError as e:
        print('Error trying to save file: {}'.format(e))
# ---------------------------------------------------------------------------------------------------------------
# Example #4
# ---------------------------------------------------------------------------------------------------------------

torch.manual_seed(args.seed)
np.random.seed(args.seed)

torch.set_num_threads(1)

tfms = trn.Compose([trn.ToTensor(), trn.Normalize((.1307, ), (.3081, ))])
full_train_data_in = dset.MNIST('../../../datasets/MNIST',
                                train=True,
                                transform=tfms)
test_data = dset.MNIST('../../../datasets/MNIST', train=False, transform=tfms)
num_classes = 10

# Splitting the data into train and validation
train_data_in, val_data_in, _, _ = dataset.train_valid_split(
    full_train_data_in, split_fold=10, random_seed=args.dataset_seed)

calib_indicator = ''
if args.calibration:
    train_data_in, val_data = validation_split(train_data_in, val_share=0.1)
    calib_indicator = 'calib_'

#tiny_images = TinyImages(transform=trn.Compose(
#    [trn.ToTensor(), trn.ToPILImage(), trn.Resize(28),
#     trn.Lambda(lambda x: x.convert('L', (0.2989, 0.5870, 0.1140, 0))),
#     trn.RandomHorizontalFlip(), trn.ToTensor()]))

# Instead of loading tiny images, we load FashionMNIST
tfms_fashionmnist = trn.Compose(
    [trn.ToTensor(), trn.Normalize((.2859, ), (.3530, ))])
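
# A sketch (commented out) of how the FashionMNIST outlier set could then be
# created with these transforms, mirroring the dset.MNIST calls above; the
# path and flags are assumptions:
#
# ood_data = dset.FashionMNIST('../../../datasets/FashionMNIST', train=True,
#                              download=True, transform=tfms_fashionmnist)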