def main(raw_args=None):
    parser = argparse.ArgumentParser(
        description="Hyperspectral image classification with FixMatch")
    parser.add_argument('--patch_size', type=int, default=5,
                        help='Size of patch around each pixel taken for classification')
    parser.add_argument('--center_pixel', action='store_false',
                        help='use if you only want to consider the label of the center pixel of a patch')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='Size of each batch for training')
    parser.add_argument('--epochs', type=int, default=10,
                        help='number of total epochs of training to run')
    parser.add_argument('--dataset', type=str, default='Salinas',
                        help='Name of dataset to run, Salinas or PaviaU')
    parser.add_argument('--cuda', type=int, default=-1,
                        help='which CUDA device to run on, -1 defaults to cpu')
    parser.add_argument('--warmup', type=float, default=0,
                        help='number of warmup epochs')
    parser.add_argument('--save', action='store_true',
                        help='use to save model weights when running')
    parser.add_argument('--test_stride', type=int, default=1,
                        help='length of stride when sliding patch window over image for testing')
    parser.add_argument('--sampling_percentage', type=float, default=0.3,
                        help='percentage of dataset to sample for training (labeled and unlabeled included)')
    parser.add_argument('--sampling_mode', type=str, default='nalepa',
                        help='how to sample data: disjoint, random, nalepa, or fixed')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='initial learning rate')
    parser.add_argument('--alpha', type=float, default=1.0,
                        help='beta distribution range')
    parser.add_argument('--class_balancing', action='store_false',
                        help='use to balance class weights according to their ratio in the dataset')
    parser.add_argument('--checkpoint', type=str, default=None,
                        help='use to load model weights from a certain directory')
    # Augmentation arguments
    parser.add_argument('--flip_augmentation', action='store_true',
                        help='use flip augmentation on the data')
    parser.add_argument('--radiation_augmentation', action='store_true',
                        help='use radiation noise augmentation on the data')
    parser.add_argument('--mixture_augmentation', action='store_true',
                        help='use mixture noise augmentation on the data')
    parser.add_argument('--pca_augmentation', action='store_true',
                        help='use PCA augmentation on the data')
    parser.add_argument('--pca_strength', type=float, default=1.0,
                        help='Strength of the PCA augmentation, defaults to 1.')
    parser.add_argument('--cutout_spatial', action='store_true',
                        help='use spatial cutout for data augmentation')
    parser.add_argument('--cutout_spectral', action='store_true',
                        help='use spectral cutout for data augmentation')
    parser.add_argument('--augmentation_magnitude', type=int, default=1,
                        help='Magnitude of augmentation (so far only for cutout). Defaults to 1, min 1 and max 10.')
    parser.add_argument('--spatial_combinations', action='store_true',
                        help='use spatial combination for data augmentation')
    parser.add_argument('--spectral_mean', action='store_true',
                        help='use spectral mean for data augmentation')
    parser.add_argument('--moving_average', action='store_true',
                        help='use spectral moving average for data augmentation')
    parser.add_argument('--results', type=str, default='results',
                        help='where to save results to (default results)')
    parser.add_argument('--save_dir', type=str, default='/saves/',
                        help='where to save models to (default /saves/)')
    parser.add_argument('--data_dir', type=str, default='/data/',
                        help='where to fetch data from (default /data/)')
    parser.add_argument('--load_file', type=str, default=None,
                        help='which file to load weights from (default None)')
    parser.add_argument('--fold', type=int, default=0,
                        help="which fold to sample from if using Nalepa's validation scheme")
    parser.add_argument('--sampling_fixed', type=str, default='True',
                        help='use to sample a fixed amount of samples for each class with Nalepa sampling')
    parser.add_argument('--samples_per_class', type=int, default=10,
                        help='amount of samples to sample for each class when sampling a fixed amount. Defaults to 10.')
    parser.add_argument('--supervision', type=str, default='full',
                        help='whether to make use of all labeled data or not, full or semi')

    args = parser.parse_args(raw_args)
    device = utils.get_device(args.cuda)
    args.device = device

    # vis = visdom.Visdom()
    vis = None

    tensorboard_dir = str(args.results + '/' +
                          datetime.datetime.now().strftime("%m-%d-%X"))
    os.makedirs(tensorboard_dir, exist_ok=True)
    writer = SummaryWriter(tensorboard_dir)

    if args.sampling_mode == 'nalepa':
        train_img, train_gt, test_img, test_gt, label_values, ignored_labels, \
            rgb_bands, palette = get_patch_data(args.dataset, args.patch_size,
                                                target_folder=args.data_dir,
                                                fold=args.fold)
        args.n_bands = train_img.shape[-1]
    else:
        img, gt, label_values, ignored_labels, rgb_bands, palette = get_dataset(
            args.dataset, target_folder=args.data_dir)
        args.n_bands = img.shape[-1]

    args.n_classes = len(label_values) - len(ignored_labels)
    args.ignored_labels = ignored_labels

    if palette is None:
        # Generate color palette
        palette = {0: (0, 0, 0)}
        for k, color in enumerate(sns.color_palette("hls", len(label_values) - 1)):
            palette[k + 1] = tuple(np.asarray(255 * np.array(color), dtype='uint8'))
    invert_palette = {v: k for k, v in palette.items()}

    def convert_to_color(x):
        return utils.convert_to_color_(x, palette=palette)

    def convert_from_color(x):
        return utils.convert_from_color_(x, palette=invert_palette)

    if args.sampling_mode == 'nalepa':
        print("{} samples selected (over {})".format(
            np.count_nonzero(train_gt),
            np.count_nonzero(train_gt) + np.count_nonzero(test_gt)))
        writer.add_text(
            'Amount of training samples',
            "{} samples selected (over {})".format(
                np.count_nonzero(train_gt),
                np.count_nonzero(train_gt) + np.count_nonzero(test_gt)))
        utils.display_predictions(convert_to_color(test_gt), vis,
                                  writer=writer, caption="Test ground truth")
    else:
        train_gt, test_gt = utils.sample_gt(gt, args.sampling_percentage,
                                            mode=args.sampling_mode)
        print("{} samples selected (over {})".format(
            np.count_nonzero(train_gt), np.count_nonzero(gt)))
        writer.add_text(
            'Amount of training samples',
            "{} samples selected (over {})".format(np.count_nonzero(train_gt),
                                                   np.count_nonzero(gt)))
        utils.display_predictions(convert_to_color(train_gt), vis,
                                  writer=writer, caption="Train ground truth")
        utils.display_predictions(convert_to_color(test_gt), vis,
                                  writer=writer, caption="Test ground truth")
ground truth") model = HamidaEtAl(args.n_bands, args.n_classes, patch_size=args.patch_size) optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True, weight_decay=0.0005) #loss_labeled = nn.CrossEntropyLoss(weight=weights) #loss_unlabeled = nn.CrossEntropyLoss(weight=weights, reduction='none') if args.sampling_mode == 'nalepa': #Get fixed amount of random samples for validation idx_sup, idx_val, idx_unsup = get_pixel_idx(train_img, train_gt, args.ignored_labels, args.patch_size) if args.sampling_fixed == 'True': unique_labels = np.zeros(len(label_values)) new_idx_sup = [] index = 0 for p, x, y in idx_sup: label = train_gt[p, x, y] if unique_labels[label] < args.samples_per_class: unique_labels[label] += 1 new_idx_sup.append([p, x, y]) np.delete(idx_sup, index) index += 1 idx_unsup = np.concatenate((idx_sup, idx_unsup)) idx_sup = np.asarray(new_idx_sup) writer.add_text( 'Amount of labeled training samples', "{} samples selected (over {})".format(idx_sup.shape[0], np.count_nonzero(train_gt))) train_labeled_gt = [ train_gt[p_l, x_l, y_l] for p_l, x_l, y_l in idx_sup ] samples_class = np.zeros(args.n_classes) for c in np.unique(train_labeled_gt): samples_class[c - 1] = np.count_nonzero(train_labeled_gt == c) writer.add_text('Labeled samples per class', str(samples_class)) print('Labeled samples per class: ' + str(samples_class)) val_dataset = HyperX_patches(train_img, train_gt, idx_val, labeled='Val', **vars(args)) val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size) train_dataset = HyperX_patches(train_img, train_gt, idx_sup, labeled=True, **vars(args)) train_loader = data.DataLoader( train_dataset, batch_size=args.batch_size, #pin_memory=True, num_workers=5, shuffle=True, drop_last=True) amount_labeled = idx_sup.shape[0] else: train_labeled_gt, val_gt = utils.sample_gt(train_gt, 0.95, mode=args.sampling_mode) val_dataset = HyperX(img, val_gt, labeled='Val', **vars(args)) val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size) writer.add_text( 'Amount of labeled training samples', "{} samples selected (over {})".format( np.count_nonzero(train_labeled_gt), np.count_nonzero(train_gt))) samples_class = np.zeros(args.n_classes) for c in np.unique(train_labeled_gt): samples_class[c - 1] = np.count_nonzero(train_labeled_gt == c) writer.add_text('Labeled samples per class', str(samples_class)) train_dataset = HyperX(img, train_labeled_gt, labeled=True, **vars(args)) train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True, num_workers=5, shuffle=True, drop_last=True) utils.display_predictions(convert_to_color(train_labeled_gt), vis, writer=writer, caption="Labeled train ground truth") utils.display_predictions(convert_to_color(val_gt), vis, writer=writer, caption="Validation ground truth") amount_labeled = np.count_nonzero(train_labeled_gt) args.iterations = amount_labeled // args.batch_size args.total_steps = args.iterations * args.epochs args.scheduler = get_cosine_schedule_with_warmup( optimizer, args.warmup * args.iterations, args.total_steps) if args.class_balancing: weights_balance = utils.compute_imf_weights(train_gt, len(label_values), args.ignored_labels) args.weights = torch.from_numpy(weights_balance[1:]) args.weights = args.weights.to(torch.float32) else: weights = torch.ones(args.n_classes) #weights[torch.LongTensor(args.ignored_labels)] = 0 args.weights = weights args.weights = args.weights.to(args.device) criterion = nn.CrossEntropyLoss(weight=args.weights) loss_val = 
    print(args)
    print("Network :")
    writer.add_text('Arguments', str(args))
    with torch.no_grad():
        for input, _ in train_loader:
            break
        # summary(model.to(device), input.size()[1:])
        # writer.add_graph(model.to(device), input)
        # We would like to use device=hyperparams['device'] although we have
        # to wait for torchsummary to be fixed first.

    if args.load_file is not None:
        model.load_state_dict(torch.load(args.load_file))
    model.zero_grad()

    try:
        train(model, optimizer, criterion, loss_val, train_loader, writer,
              args, val_loader=val_loader, display=vis)
    except KeyboardInterrupt:
        # Allow the user to stop the training
        pass

    if args.sampling_mode == 'nalepa':
        probabilities = test(model, test_img, args)
    else:
        probabilities = test(model, img, args)
    prediction = np.argmax(probabilities, axis=-1)

    run_results = utils.metrics(prediction, test_gt,
                                ignored_labels=args.ignored_labels,
                                n_classes=args.n_classes)

    mask = np.zeros(test_gt.shape, dtype='bool')
    for l in args.ignored_labels:
        mask[test_gt == l] = True
    prediction += 1
    prediction[mask] = 0

    color_prediction = convert_to_color(prediction)
    utils.display_predictions(color_prediction, vis,
                              gt=convert_to_color(test_gt), writer=writer,
                              caption="Prediction vs. test ground truth")

    utils.show_results(run_results, vis, writer=writer,
                       label_values=label_values)

    writer.close()

    return run_results
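
# Usage sketch (an assumption, not part of the original script): main() takes
# raw_args so it can be invoked either from the command line or
# programmatically, e.g. from a small sweep script. The flags shown are ones
# defined by the parser above; the __main__ guard is a common convention and
# may already exist elsewhere in this repository.
#
# if __name__ == '__main__':
#     # CLI run with the parser defaults (Salinas, nalepa sampling, CPU):
#     results = main()
#     # Programmatic run overriding a few defaults:
#     # results = main(['--dataset', 'PaviaU', '--epochs', '20',
#     #                 '--cuda', '0', '--samples_per_class', '5'])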
    scaler = sklearn.preprocessing.StandardScaler()
    X_train = scaler.fit_transform(X_train)
    class_weight = 'balanced' if CLASS_BALANCING else None
    clf = sklearn.linear_model.SGDClassifier(class_weight=class_weight,
                                             learning_rate='optimal',
                                             tol=1e-3, average=10)
    clf.fit(X_train, y_train)
    save_model(clf, MODEL, DATASET)
    prediction = clf.predict(scaler.transform(img.reshape(-1, N_BANDS)))
    prediction = prediction.reshape(img.shape[:2])
else:
    # Custom algorithm
    # Neural network
    model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams)
    if CLASS_BALANCING:
        weights = compute_imf_weights(train_gt, N_CLASSES, IGNORED_LABELS)
        hyperparams['weights'] = torch.from_numpy(weights)
    # Split train set in train/val
    # train_gt, val_gt = sample_gt(train_gt, 0.9, mode='random')
    # Generate the dataset
    train_dataset = HyperX(img, train_gt, **hyperparams)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=hyperparams['batch_size'],
                                   pin_memory=hyperparams['device'],
                                   shuffle=True)
    # val_dataset = HyperX(img, val_gt, **hyperparams)
    # val_loader = data.DataLoader(val_dataset,
    #                              pin_memory=hyperparams['device'],
    #                              batch_size=hyperparams['batch_size'])
    print(hyperparams)
def train_model(img, gt, hyperparams):
    """Function for model training.

    1) Data sampling into a training, a validation and a test set.
    2) Training a chosen model.
    3) Model evaluation.

    Arguments:
        img - dataset (hyperspectral image)
        gt - ground truth (labels)
        hyperparams - parameters of training
        SVM_GRID_PARAMS - parameters for SVM (if used)
        FOLDER - a path for datasets
        DATASET - name of the used dataset
        set_parameters - option for loading a specific training and test set
        preprocessing_parameters - parameters of preprocessing
    """
    print("img.shape: {}".format(img.shape))
    print("gt.shape: {}".format(gt.shape))
    # all images should have 113 bands
    assert img.shape[2] == 113

    viz = None
    results = []
    # run the experiment several times
    for run in range(hyperparams['runs']):
        #######################################################################
        # Create a training and a test set
        if hyperparams['train_gt'] is not None and hyperparams['test_gt'] is not None:
            train_gt = open_file(hyperparams['train_gt'])
            test_gt = open_file(hyperparams['test_gt'])
        elif hyperparams['train_gt'] is not None:
            train_gt = open_file(hyperparams['train_gt'])
            test_gt = np.copy(gt)
            w, h = test_gt.shape
            test_gt[(train_gt > 0)[:w, :h]] = 0
        elif hyperparams['test_gt'] is not None:
            test_gt = open_file(hyperparams['test_gt'])
        else:
            # Choose type of data sampling
            if hyperparams['sampling_mode'] == 'uniform':
                train_gt, test_gt = select_subset(gt, hyperparams['training_sample'])
                check_split_correctness(gt, train_gt, test_gt,
                                        hyperparams['n_classes'])
            elif hyperparams['sampling_mode'] == 'fixed':
                # load fixed sets from a given path
                train_gt, test_gt = get_fixed_sets(run, hyperparams['sample_path'],
                                                   hyperparams['dataset'])
                check_split_correctness(gt, train_gt, test_gt,
                                        hyperparams['n_classes'], 'fixed')
            else:
                train_gt, test_gt = sample_gt(gt, hyperparams['training_sample'],
                                              mode=hyperparams['sampling_mode'])

        print("{} samples selected (over {})".format(np.count_nonzero(train_gt),
                                                     np.count_nonzero(gt)))
        print("Running an experiment with the {} model".format(hyperparams['model']),
              "run {}/{}".format(run + 1, hyperparams['runs']))

        #######################################################################
        # Train a model
        if hyperparams['model'] == 'SVM_grid':
            print("Running a grid search SVM")
            # Grid search SVM (linear and RBF)
            X_train, y_train = build_dataset(img, train_gt,
                                             ignored_labels=hyperparams['ignored_labels'])
            class_weight = 'balanced' if hyperparams['class_balancing'] else None
            clf = sklearn.svm.SVC(class_weight=class_weight)
            clf = sklearn.model_selection.GridSearchCV(clf,
                                                       hyperparams['svm_grid_params'],
                                                       verbose=5, n_jobs=4)
            clf.fit(X_train, y_train)
            print("SVM best parameters : {}".format(clf.best_params_))
            prediction = clf.predict(img.reshape(-1, hyperparams['n_bands']))
            save_model(clf, hyperparams['model'], hyperparams['dataset'],
                       hyperparams['rdir'])
            prediction = prediction.reshape(img.shape[:2])
        elif hyperparams['model'] == 'SVM':
            X_train, y_train = build_dataset(img, train_gt,
                                             ignored_labels=hyperparams['ignored_labels'])
            class_weight = 'balanced' if hyperparams['class_balancing'] else None
            clf = sklearn.svm.SVC(class_weight=class_weight)
            clf.fit(X_train, y_train)
            save_model(clf, hyperparams['model'], hyperparams['dataset'],
                       hyperparams['rdir'])
            prediction = clf.predict(img.reshape(-1, hyperparams['n_bands']))
            prediction = prediction.reshape(img.shape[:2])
        elif hyperparams['model'] == 'SGD':
            X_train, y_train = build_dataset(img, train_gt,
                                             ignored_labels=hyperparams['ignored_labels'])
            X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
            scaler = sklearn.preprocessing.StandardScaler()
            X_train = scaler.fit_transform(X_train)
            class_weight = 'balanced' if hyperparams['class_balancing'] else None
            clf = sklearn.linear_model.SGDClassifier(class_weight=class_weight,
                                                     learning_rate='optimal',
                                                     tol=1e-3, average=10)
            clf.fit(X_train, y_train)
            save_model(clf, hyperparams['model'], hyperparams['dataset'],
                       hyperparams['rdir'])
            prediction = clf.predict(scaler.transform(img.reshape(-1, hyperparams['n_bands'])))
            prediction = prediction.reshape(img.shape[:2])
        elif hyperparams['model'] == 'nearest':
            X_train, y_train = build_dataset(img, train_gt,
                                             ignored_labels=hyperparams['ignored_labels'])
            X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
            class_weight = 'balanced' if hyperparams['class_balancing'] else None
            clf = sklearn.neighbors.KNeighborsClassifier(weights='distance')
            clf = sklearn.model_selection.GridSearchCV(clf,
                                                       {'n_neighbors': [1, 3, 5, 10, 20]},
                                                       verbose=5, n_jobs=4)
            clf.fit(X_train, y_train)
            save_model(clf, hyperparams['model'], hyperparams['dataset'],
                       hyperparams['rdir'])
            prediction = clf.predict(img.reshape(-1, hyperparams['n_bands']))
            prediction = prediction.reshape(img.shape[:2])
        else:
            # Neural network
            model, optimizer, loss, hyperparams = get_model(hyperparams['model'],
                                                            **hyperparams)
            if hyperparams['class_balancing']:
                weights = compute_imf_weights(train_gt, hyperparams['n_classes'],
                                              hyperparams['ignored_labels'])
                hyperparams['weights'] = torch.from_numpy(weights)
            # Split train set in train/val
            if hyperparams['sampling_mode'] in {'uniform', 'fixed'}:
                train_gt, val_gt = select_subset(train_gt, 0.95)
            else:
                train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random')
            # Generate the dataset
            train_dataset = HyperX(img, train_gt, **hyperparams)
            train_loader = data.DataLoader(train_dataset,
                                           batch_size=hyperparams['batch_size'],
                                           shuffle=True)
            val_dataset = HyperX(img, val_gt, **hyperparams)
            val_loader = data.DataLoader(val_dataset,
                                         batch_size=hyperparams['batch_size'])

            print(hyperparams)
            print("Network :")
            with torch.no_grad():
                for input, _ in train_loader:
                    break
                summary(model.to(hyperparams['device']), input.size()[1:])
                # We would like to use device=hyperparams['device'] although we
                # have to wait for torchsummary to be fixed first.

            if hyperparams['checkpoint'] is not None:
                model.load_state_dict(torch.load(hyperparams['checkpoint']))

            try:
                train(model, optimizer, loss, train_loader, hyperparams['epoch'],
                      scheduler=hyperparams['scheduler'],
                      device=hyperparams['device'],
                      supervision=hyperparams['supervision'],
                      val_loader=val_loader,
                      display=viz,
                      rdir=hyperparams['rdir'],
                      model_name=hyperparams['model'],
                      preprocessing=hyperparams['preprocessing']['type'],
                      run=run)
            except KeyboardInterrupt:
                # Allow the user to stop the training
                pass

            probabilities = test(model, img, hyperparams)
            prediction = np.argmax(probabilities, axis=-1)

        #######################################################################
        # Evaluate the model
        # If test set is not empty
        if np.unique(test_gt).shape[0] > 1:
            run_results = metrics(prediction, test_gt,
                                  ignored_labels=hyperparams['ignored_labels'],
                                  n_classes=hyperparams['n_classes'])

        mask = np.zeros(gt.shape, dtype='bool')
        for l in hyperparams['ignored_labels']:
            mask[gt == l] = True
        prediction[mask] = 0