Example #1
0
def _build_arg_parser():
    """Build the command-line parser used by :func:`main`.

    Every option keeps its original name, type, default and action. Only the
    help text of the two ``store_false`` flags was reworded so that it
    describes what passing the flag actually does.
    """
    parser = argparse.ArgumentParser(
        description="Hyperspectral image classification with FixMatch")
    add = parser.add_argument

    # --- Data, sampling and optimisation -------------------------------
    add('--patch_size', type=int, default=5,
        help='Size of patch around each pixel taken for classification')
    # NOTE: store_false -> the value is True unless the flag is given.
    add('--center_pixel', action='store_false',
        help='consider only the label of the center pixel of a patch '
        '(enabled by default; pass this flag to disable)')
    add('--batch_size', type=int, default=10,
        help='Size of each batch for training')
    add('--epochs', type=int, default=10,
        help='number of total epochs of training to run')
    add('--dataset', type=str, default='Salinas',
        help='Name of dataset to run, Salinas or PaviaU')
    add('--cuda', type=int, default=-1,
        help='what CUDA device to run on, -1 defaults to cpu')
    add('--warmup', type=float, default=0, help='warmup epochs')
    add('--save', action='store_true',
        help='use to save model weights when running')
    add('--test_stride', type=int, default=1,
        help='length of stride when sliding patch window over image for testing')
    add('--sampling_percentage', type=float, default=0.3,
        help='percentage of dataset to sample for training (labeled and unlabeled included)')
    add('--sampling_mode', type=str, default='nalepa',
        help='how to sample data, disjoint, random, nalepa, or fixed')
    add('--lr', type=float, default=0.001, help='initial learning rate')
    add('--alpha', type=float, default=1.0, help='beta distribution range')
    # NOTE: store_false -> the value is True unless the flag is given.
    add('--class_balancing', action='store_false',
        help='balance class weights according to their ratio in the dataset '
        '(enabled by default; pass this flag to disable)')
    add('--checkpoint', type=str, default=None,
        help='use to load model weights from a certain directory')

    # --- Augmentation ---------------------------------------------------
    add('--flip_augmentation', action='store_true',
        help='use to flip augmentation data for use')
    add('--radiation_augmentation', action='store_true',
        help='use to radiation noise data for use')
    add('--mixture_augmentation', action='store_true',
        help='use to mixture noise data for use')
    add('--pca_augmentation', action='store_true',
        help='use to pca augment data for use')
    add('--pca_strength', type=float, default=1.0,
        help='Strength of the PCA augmentation, defaults to 1.')
    add('--cutout_spatial', action='store_true',
        help='use to cutout spatial for data augmentation')
    add('--cutout_spectral', action='store_true',
        help='use to cutout spectral for data augmentation')
    add('--augmentation_magnitude', type=int, default=1,
        help='Magnitude of augmentation (so far only for cutout). Defualts to 1, min 1 and max 10.')
    add('--spatial_combinations', action='store_true',
        help='use to spatial combine for data augmentation')
    add('--spectral_mean', action='store_true',
        help='use to spectal mean for data augmentation')
    add('--moving_average', action='store_true',
        help='use to sprectral moving average for data augmentation')

    # --- Paths and Nalepa-specific sampling -------------------------------
    add('--results', type=str, default='results',
        help='where to save results to (default results)')
    add('--save_dir', type=str, default='/saves/',
        help='where to save models to (default /saves/)')
    add('--data_dir', type=str, default='/data/',
        help='where to fetch data from (default /data/)')
    add('--load_file', type=str, default=None,
        help='wihch file to load weights from (default None)')
    add('--fold', type=int, default=0,
        help='Which fold to sample from if using Nalepas validation scheme')
    # Kept as a string option; only the exact value 'True' enables it.
    add('--sampling_fixed', type=str, default='True',
        help='Use to sample a fixed amount of samples for each class from Nalepa sampling')
    add('--samples_per_class', type=int, default=10,
        help='Amount of samples to sample for each class when sampling a fixed amount. Defaults to 10.')
    add('--supervision', type=str, default='full',
        help='check this more, use to make us of all labeled or not, full or semi')
    return parser


def _select_fixed_labeled(idx_sup, train_gt, n_labels, samples_per_class):
    """Keep at most ``samples_per_class`` labeled indices for each label.

    ``idx_sup`` is walked in order as ``(p, x, y)`` triples and each triple is
    used to look up its label in ``train_gt``. The first ``samples_per_class``
    triples seen for a label are kept; later ones are dropped.

    Returns:
        ``np.ndarray`` built from the kept ``[p, x, y]`` triples.
    """
    taken = np.zeros(n_labels, dtype=int)
    kept = []
    for p, x, y in idx_sup:
        label = train_gt[p, x, y]
        if taken[label] < samples_per_class:
            taken[label] += 1
            kept.append([p, x, y])
    return np.asarray(kept)


def _count_labeled_per_class(labels, n_classes, ignored_labels):
    """Count how many entries of ``labels`` belong to each class.

    Slot ``c - 1`` of the result holds the count for label ``c``. Labels listed
    in ``ignored_labels`` are skipped so that an ignored label ``0`` can no
    longer land in slot ``-1`` (the last class).

    NOTE(review): the ``c - 1`` mapping assumes classes are numbered
    ``1..n_classes`` -- confirm against the dataset loaders.
    """
    labels = np.asarray(labels)
    counts = np.zeros(n_classes)
    for c in np.unique(labels):
        if c in ignored_labels:
            continue
        counts[c - 1] = np.count_nonzero(labels == c)
    return counts


def _run_experiment(args, writer):
    """Load data, train the model, evaluate it and log everything to ``writer``.

    Args:
        args: parsed ``argparse.Namespace``; it is extended in place with
            ``n_bands``, ``n_classes``, ``ignored_labels``, ``iterations``,
            ``total_steps``, ``scheduler`` and ``weights``.
        writer: summary writer that receives texts and images.

    Returns:
        The value returned by ``utils.metrics`` for the test ground truth.
    """
    # Visdom output is disabled; the display helpers are handed ``None``.
    vis = None
    use_nalepa = args.sampling_mode == 'nalepa'

    # ------------------------------------------------------------------
    # Load the data
    if use_nalepa:
        (train_img, train_gt, test_img, test_gt, label_values, ignored_labels,
         _rgb_bands, palette) = get_patch_data(args.dataset,
                                               args.patch_size,
                                               target_folder=args.data_dir,
                                               fold=args.fold)
        args.n_bands = train_img.shape[-1]
    else:
        img, gt, label_values, ignored_labels, _rgb_bands, palette = get_dataset(
            args.dataset, target_folder=args.data_dir)
        args.n_bands = img.shape[-1]

    args.n_classes = len(label_values) - len(ignored_labels)
    args.ignored_labels = ignored_labels

    if palette is None:
        # Generate a colour palette: label 0 is black, the others get "hls" hues.
        palette = {0: (0, 0, 0)}
        hues = sns.color_palette("hls", len(label_values) - 1)
        for k, color in enumerate(hues):
            palette[k + 1] = tuple(
                np.asarray(255 * np.array(color), dtype='uint8'))

    def convert_to_color(x):
        return utils.convert_to_color_(x, palette=palette)

    # ------------------------------------------------------------------
    # Train / test split and its logging
    if use_nalepa:
        n_train = np.count_nonzero(train_gt)
        # Console and TensorBoard now report the same total (train + test).
        split_msg = "{} samples selected (over {})".format(
            n_train, n_train + np.count_nonzero(test_gt))
        print(split_msg)
        writer.add_text('Amount of training samples', split_msg)
    else:
        train_gt, test_gt = utils.sample_gt(gt,
                                            args.sampling_percentage,
                                            mode=args.sampling_mode)
        split_msg = "{} samples selected (over {})".format(
            np.count_nonzero(train_gt), np.count_nonzero(gt))
        print(split_msg)
        writer.add_text('Amount of training samples', split_msg)
        utils.display_predictions(convert_to_color(train_gt),
                                  vis,
                                  writer=writer,
                                  caption="Train ground truth")
    utils.display_predictions(convert_to_color(test_gt),
                              vis,
                              writer=writer,
                              caption="Test ground truth")

    # ------------------------------------------------------------------
    # Model and optimiser
    model = HamidaEtAl(args.n_bands,
                       args.n_classes,
                       patch_size=args.patch_size)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          nesterov=True,
                          weight_decay=0.0005)

    # ------------------------------------------------------------------
    # Labeled training set and validation set
    if use_nalepa:
        # The third value (unlabeled indices) is not used by this function.
        idx_sup, idx_val, _ = get_pixel_idx(train_img, train_gt,
                                            args.ignored_labels,
                                            args.patch_size)
        if args.sampling_fixed == 'True':
            idx_sup = _select_fixed_labeled(idx_sup, train_gt,
                                            len(label_values),
                                            args.samples_per_class)

        writer.add_text(
            'Amount of labeled training samples',
            "{} samples selected (over {})".format(idx_sup.shape[0],
                                                   np.count_nonzero(train_gt)))
        train_labeled_gt = [
            train_gt[p_l, x_l, y_l] for p_l, x_l, y_l in idx_sup
        ]
        samples_class = _count_labeled_per_class(train_labeled_gt,
                                                 args.n_classes,
                                                 args.ignored_labels)
        writer.add_text('Labeled samples per class', str(samples_class))
        print('Labeled samples per class: ' + str(samples_class))

        val_dataset = HyperX_patches(train_img,
                                     train_gt,
                                     idx_val,
                                     labeled='Val',
                                     **vars(args))
        val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size)

        train_dataset = HyperX_patches(train_img,
                                       train_gt,
                                       idx_sup,
                                       labeled=True,
                                       **vars(args))
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       drop_last=True)

        amount_labeled = idx_sup.shape[0]
    else:
        train_labeled_gt, val_gt = utils.sample_gt(train_gt,
                                                   0.95,
                                                   mode=args.sampling_mode)

        val_dataset = HyperX(img, val_gt, labeled='Val', **vars(args))
        val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size)

        writer.add_text(
            'Amount of labeled training samples',
            "{} samples selected (over {})".format(
                np.count_nonzero(train_labeled_gt),
                np.count_nonzero(train_gt)))
        samples_class = _count_labeled_per_class(train_labeled_gt,
                                                 args.n_classes,
                                                 args.ignored_labels)
        writer.add_text('Labeled samples per class', str(samples_class))

        train_dataset = HyperX(img,
                               train_labeled_gt,
                               labeled=True,
                               **vars(args))
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       pin_memory=True,
                                       num_workers=5,
                                       shuffle=True,
                                       drop_last=True)

        utils.display_predictions(convert_to_color(train_labeled_gt),
                                  vis,
                                  writer=writer,
                                  caption="Labeled train ground truth")
        utils.display_predictions(convert_to_color(val_gt),
                                  vis,
                                  writer=writer,
                                  caption="Validation ground truth")

        amount_labeled = np.count_nonzero(train_labeled_gt)

    # ------------------------------------------------------------------
    # Learning-rate schedule (batches are dropped when incomplete, hence //)
    args.iterations = amount_labeled // args.batch_size
    args.total_steps = args.iterations * args.epochs
    args.scheduler = get_cosine_schedule_with_warmup(
        optimizer, args.warmup * args.iterations, args.total_steps)

    # ------------------------------------------------------------------
    # Loss weights
    if args.class_balancing:
        weights_balance = utils.compute_imf_weights(train_gt,
                                                    len(label_values),
                                                    args.ignored_labels)
        # The first entry is dropped so the vector lines up with the
        # n_classes outputs -- presumably it belongs to label 0; TODO confirm.
        args.weights = torch.from_numpy(weights_balance[1:]).to(torch.float32)
    else:
        args.weights = torch.ones(args.n_classes)
    args.weights = args.weights.to(args.device)

    criterion = nn.CrossEntropyLoss(weight=args.weights)
    loss_val = nn.CrossEntropyLoss(weight=args.weights)

    print(args)
    print("Network :")
    writer.add_text('Arguments', str(args))

    if args.load_file is not None:
        # map_location lets a checkpoint saved on another device be restored.
        model.load_state_dict(
            torch.load(args.load_file, map_location=args.device))
    model.zero_grad()

    try:
        train(model,
              optimizer,
              criterion,
              loss_val,
              train_loader,
              writer,
              args,
              val_loader=val_loader,
              display=vis)
    except KeyboardInterrupt:
        # Allow the user to stop the training and still get an evaluation.
        pass

    # ------------------------------------------------------------------
    # Evaluation
    probabilities = test(model, test_img if use_nalepa else img, args)
    prediction = np.argmax(probabilities, axis=-1)

    run_results = utils.metrics(prediction,
                                test_gt,
                                ignored_labels=args.ignored_labels,
                                n_classes=args.n_classes)

    # Shift predictions by one so 0 can mark the ignored pixels in the plot.
    mask = np.zeros(test_gt.shape, dtype='bool')
    for ignored in args.ignored_labels:
        mask[test_gt == ignored] = True
    prediction += 1
    prediction[mask] = 0

    utils.display_predictions(convert_to_color(prediction),
                              vis,
                              gt=convert_to_color(test_gt),
                              writer=writer,
                              caption="Prediction vs. test ground truth")
    utils.show_results(run_results,
                       vis,
                       writer=writer,
                       label_values=label_values)
    return run_results


def main(raw_args=None):
    """Parse the command line, then train and evaluate a classifier.

    Args:
        raw_args: list of argument strings handed to ``parse_args``; with
            ``None`` argparse falls back to the process command line.

    Returns:
        The metrics object produced by ``utils.metrics`` on the test split.

    The summary writer is closed even when loading, training or evaluation
    raises.
    """
    args = _build_arg_parser().parse_args(raw_args)
    args.device = utils.get_device(args.cuda)

    tensorboard_dir = os.path.join(
        args.results, datetime.datetime.now().strftime("%m-%d-%X"))
    os.makedirs(tensorboard_dir, exist_ok=True)
    writer = SummaryWriter(tensorboard_dir)
    try:
        return _run_experiment(args, writer)
    finally:
        writer.close()
Example #2
0
        scaler = sklearn.preprocessing.StandardScaler()
        X_train = scaler.fit_transform(X_train)
        class_weight = 'balanced' if CLASS_BALANCING else None
        clf = sklearn.linear_model.SGDClassifier(class_weight=class_weight,
                                                 learning_rate='optimal',
                                                 tol=1e-3,
                                                 average=10)
        clf.fit(X_train, y_train)
        save_model(clf, MODEL, DATASET)
        prediction = clf.predict(scaler.transform(img.reshape(-1, N_BANDS)))
        prediction = prediction.reshape(img.shape[:2])
    else:  # 自定义算法
        # Neural network
        model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams)
        if CLASS_BALANCING:  #
            weights = compute_imf_weights(train_gt, N_CLASSES, IGNORED_LABELS)
            hyperparams['weights'] = torch.from_numpy(weights)
        # Split train set in train/val
        # train_gt, val_gt = sample_gt(train_gt, 0.9, mode='random')
        # Generate the dataset
        train_dataset = HyperX(img, train_gt, **hyperparams)
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=hyperparams['batch_size'],
                                       pin_memory=hyperparams['device'],
                                       shuffle=True)
        # val_dataset = HyperX(img, val_gt, **hyperparams)
        # val_loader = data.DataLoader(val_dataset,
        #                              pin_memory=hyperparams['device'],
        #                              batch_size=hyperparams['batch_size'])

        print(hyperparams)
Example #3
0
def train_model(img, gt, hyperparams):
    """
    Function for model training.
    1) Data sampling into a training, a validation and a test set.
    2) Training a chosen model.
    3) Model evaluation.

    Arguments:
    img - dataset (hyperspectral image)
    gt - ground truth (labels)
    hyperparams - parameters of training
    SVM_GRID_PARAMS - parameters for SVM (if used)
    FOLDER - a path for datasets
    DATASET - name of the used dataset 
    set_parameters: option for loading a specific training and test set
    preprocessing_parameters: parameters of preprocessing
    """
    print("img.shape: {}".format(img.shape))
    print("gt.shape: {}".format(gt.shape))

    # all images should have 113 bands
    assert(img.shape[2] == 113)

    viz = None
    results = []
    # run the experiment several times
    for run in range(hyperparams['runs']):
        #############################################################################
        # Create a training and a test set
        if hyperparams['train_gt'] is not None and hyperparams['test_gt'] is not None:
            train_gt = open_file(hyperparams['train_gt'])
            test_gt = open_file(hyperparams['test_gt'])
        elif hyperparams['train_gt'] is not None:
            train_gt = open_file(hyperparams['train_gt'])
            test_gt = np.copy(gt)
            w, h = test_gt.shape
            test_gt[(train_gt > 0)[:w, :h]] = 0
        elif hyperparams['test_gt'] is not None:
            test_gt = open_file(hyperparams['test_gt'])
        else:
            # Choose type of data sampling
            if hyperparams['sampling_mode'] == 'uniform':
                train_gt, test_gt = select_subset(gt, hyperparams['training_sample'])
                check_split_correctness(gt, train_gt, test_gt, hyperparams['n_classes'])
            elif hyperparams['sampling_mode'] == 'fixed':
                # load fixed sets from a given path
                train_gt, test_gt = get_fixed_sets(run, hyperparams['sample_path'], hyperparams['dataset'])
                check_split_correctness(gt, train_gt, test_gt, hyperparams['n_classes'], 'fixed')
            else:
                train_gt, test_gt = sample_gt(gt,
                                              hyperparams['training_sample'],
                                              mode=hyperparams['sampling_mode'])
            
        print("{} samples selected (over {})".format(np.count_nonzero(train_gt),
                                                     np.count_nonzero(gt)))
        print("Running an experiment with the {} model".format(hyperparams['model']),
              "run {}/{}".format(run + 1, hyperparams['runs']))
        #######################################################################
        # Train a model

        if hyperparams['model'] == 'SVM_grid':
            print("Running a grid search SVM")
            # Grid search SVM (linear and RBF)
            X_train, y_train = build_dataset(img, train_gt,
                                             ignored_labels=hyperparams['ignored_labels'])
            class_weight = 'balanced' if hyperparams['class_balancing'] else None
            clf = sklearn.svm.SVC(class_weight=class_weight)
            clf = sklearn.model_selection.GridSearchCV(clf,
                                                       hyperparams['svm_grid_params'],
                                                       verbose=5,
                                                       n_jobs=4)
            clf.fit(X_train, y_train)
            print("SVM best parameters : {}".format(clf.best_params_))
            prediction = clf.predict(img.reshape(-1, hyperparams['n_bands']))
            save_model(clf,
                       hyperparams['model'],
                       hyperparams['dataset'],
                       hyperparams['rdir'])
            prediction = prediction.reshape(img.shape[:2])
        elif hyperparams['model'] == 'SVM':
            X_train, y_train = build_dataset(img, train_gt,
                                             ignored_labels=hyperparams['ignored_labels'])
            class_weight = 'balanced' if hyperparams['class_balancing'] else None
            clf = sklearn.svm.SVC(class_weight=class_weight)
            clf.fit(X_train, y_train)
            save_model(clf,
                       hyperparams['model'],
                       hyperparams['dataset'],
                       hyperparams['rdir'])
            prediction = clf.predict(img.reshape(-1, hyperparams['n_bands']))
            prediction = prediction.reshape(img.shape[:2])
        elif hyperparams['model'] == 'SGD':
            X_train, y_train = build_dataset(img, train_gt,
                                             ignored_labels=hyperparams['ignored_labels'])
            X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
            scaler = sklearn.preprocessing.StandardScaler()
            X_train = scaler.fit_transform(X_train)
            class_weight = 'balanced' if hyperparams['class_balancing'] else None
            clf = sklearn.linear_model.SGDClassifier(class_weight=class_weight,
                                                     learning_rate='optimal',
                                                     tol=1e-3,
                                                     average=10)
            clf.fit(X_train, y_train)
            save_model(clf,
                       hyperparams['model'],
                       hyperparams['dataset'],
                       hyperparams['rdir'])
            prediction = clf.predict(scaler.transform(img.reshape(-1,
                                                      hyperparams['n_bands'])))
            prediction = prediction.reshape(img.shape[:2])
        elif hyperparams['model'] == 'nearest':
            X_train, y_train = build_dataset(img,
                                             train_gt,
                                             ignored_labels=hyperparams['ignored_labels'])
            X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
            class_weight = 'balanced' if hyperparams['class_balancing'] else None
            clf = sklearn.neighbors.KNeighborsClassifier(weights='distance')
            clf = sklearn.model_selection.GridSearchCV(clf,
                                                       {'n_neighbors': [1, 3, 5, 10, 20]},
                                                       verbose=5,
                                                       n_jobs=4)
            clf.fit(X_train, y_train)
            clf.fit(X_train, y_train)
            save_model(clf,
                       hyperparams['model'],
                       hyperparams['dataset'],
                       hyperparams['rdir'])
            prediction = clf.predict(img.reshape(-1, hyperparams['n_bands']))
            prediction = prediction.reshape(img.shape[:2])
        else:
            # Neural network
            model, optimizer, loss, hyperparams = get_model(hyperparams['model'], **hyperparams)
            if hyperparams['class_balancing']:
                weights = compute_imf_weights(train_gt,
                                              hyperparams['n_classes'],
                                              hyperparams['ignored_labels'])
                hyperparams['weights'] = torch.from_numpy(weights)
            # Split train set in train/val
            if hyperparams['sampling_mode'] in {'uniform', 'fixed'}:
                train_gt, val_gt = select_subset(train_gt, 0.95)
            else:
                train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random')
            # Generate the dataset
            train_dataset = HyperX(img, train_gt, **hyperparams)
            train_loader = data.DataLoader(train_dataset,
                                           batch_size=hyperparams['batch_size'],
                                           shuffle=True)
            val_dataset = HyperX(img, val_gt, **hyperparams)
            val_loader = data.DataLoader(val_dataset,
                                         batch_size=hyperparams['batch_size'])

            print(hyperparams)
            print("Network :")
            with torch.no_grad():
                for input, _ in train_loader:
                    break
                summary(model.to(hyperparams['device']), input.size()[1:])
                # We would like to use device=hyperparams['device'] altough we have
                # to wait for torchsummary to be fixed first.

            if hyperparams['checkpoint'] is not None:
                model.load_state_dict(torch.load(hyperparams['checkpoint']))

            try:
                train(model,
                      optimizer,
                      loss,
                      train_loader,
                      hyperparams['epoch'],
                      scheduler=hyperparams['scheduler'],
                      device=hyperparams['device'],
                      supervision=hyperparams['supervision'],
                      val_loader=val_loader,
                      display=viz,
                      rdir=hyperparams['rdir'],
                      model_name=hyperparams['model'],
                      preprocessing=hyperparams['preprocessing']['type'],
                      run=run)
            except KeyboardInterrupt:
                # Allow the user to stop the training
                pass

            probabilities = test(model, img, hyperparams)
            prediction = np.argmax(probabilities, axis=-1)

        #######################################################################
        # Evaluate the model
        # If test set is not empty
        if(np.unique(test_gt).shape[0] > 1):
            run_results = metrics(prediction,
                                  test_gt,
                                  ignored_labels=hyperparams['ignored_labels'],
                                  n_classes=hyperparams['n_classes'])

        mask = np.zeros(gt.shape, dtype='bool')
        for l in hyperparams['ignored_labels']:
            mask[gt == l] = True
        prediction[mask] = 0