Code example #1
0
# run the experiment several times
for run in range(N_RUNS):
    # Build train_gt / test_gt from the command-line ground-truth files if
    # given, otherwise sample a split from the full ground truth.
    if TRAIN_GT is not None and TEST_GT is not None:
        train_gt = open_file(TRAIN_GT)
        test_gt = open_file(TEST_GT)
    elif TRAIN_GT is not None:
        # Only a train ground truth was given: test on everything else.
        train_gt = open_file(TRAIN_GT)
        test_gt = np.copy(gt)
        w, h = test_gt.shape
        # Zero out the training pixels so they are excluded from testing.
        test_gt[(train_gt > 0)[:w,:h]] = 0
    elif TEST_GT is not None:
        # NOTE(review): train_gt is not assigned in this branch; it keeps the
        # value from a previous iteration (or is undefined on run 0) -- verify.
        test_gt = open_file(TEST_GT)
    else:
    # Sample random training spectra (yields both a train and a test split)
        train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode=SAMPLING_MODE)

    print("{} samples selected (over {})".format(np.count_nonzero(train_gt),
                                                 np.count_nonzero(gt)))
    print("Running an experiment with the {} model".format(MODEL),
          "run {}/{}".format(run + 1, N_RUNS))

    # Show both ground truths in the visualization backend.
    display_predictions(convert_to_color(train_gt), viz, caption="Train ground truth")
    display_predictions(convert_to_color(test_gt), viz, caption="Test ground truth")

    if MODEL == 'SVM_grid':
        print("Running a grid search SVM")
        # Grid search SVM (linear and RBF)
        X_train, y_train = build_dataset(img, train_gt,
                                         ignored_labels=IGNORED_LABELS)
        class_weight = 'balanced' if CLASS_BALANCING else None
Code example #2
0
File: main.py  Project: githubltqc/CBW
 # Build train_gt / test_gt from the command-line ground-truth files if given,
 # otherwise sample a split from the (border-cropped) full ground truth.
 if TRAIN_GT is not None and TEST_GT is not None:
     train_gt = open_file(TRAIN_GT)
     test_gt = open_file(TEST_GT)
 elif TRAIN_GT is not None:
     train_gt = open_file(TRAIN_GT)
     test_gt = np.copy(gt)
     w, h = test_gt.shape
     # Exclude the training pixels from the test ground truth.
     test_gt[(train_gt > 0)[:w, :h]] = 0
 elif TEST_GT is not None:
     # NOTE(review): train_gt is not assigned in this branch -- confirm it is
     # defined elsewhere before use.
     test_gt = open_file(TEST_GT)
 else:
     # Sample random training spectra.
     # Crop a PATCH_SIZE//2 border so every sampled pixel has a full patch.
     gt_ = gt[(PATCH_SIZE // 2):-(PATCH_SIZE // 2),
              (PATCH_SIZE // 2):-(PATCH_SIZE // 2)]
     train_gt, test_gt = sample_gt(gt_,
                                   SAMPLE_PERCENTAGE,
                                   mode=SAMPLING_MODE)
     # ----------------------------------------------------------------------------------
     # Re-project the cropped train_gt back onto full-image coordinates.
     # NOTE(review): mask starts all-zero and the loop below only writes zeros
     # into it, so the ignored-labels pass is a no-op -- presumably it was
     # meant to start from a copy of gt. Verify the original intent.
     mask = np.zeros_like(gt)
     for l in set(hyperparams['ignored_labels']):
         mask[gt == l] = 0
     x_pos, y_pos = np.nonzero(train_gt)
     indices = np.array([(x, y) for x, y in zip(x_pos, y_pos)])
     for x, y in indices:
         # NOTE(review): `is not 0` is an identity comparison, not a value
         # test; a NumPy scalar is never the cached int 0, so this condition
         # is always True (and raises a SyntaxWarning on Python 3.8+).
         # It was probably meant to be `!= 0` -- but note that with the
         # all-zero mask above, `!= 0` would never fire. Needs review.
         if mask[x + PATCH_SIZE // 2, y + PATCH_SIZE // 2] is not 0:
             mask[x + PATCH_SIZE // 2,
                  y + PATCH_SIZE // 2] = gt[x + PATCH_SIZE // 2,
                                            y + PATCH_SIZE // 2]
     train_gt = mask
     # ----------------------------------------------------------------------------------
     # NOTE(review): this overwrites the sampled test split -- every labeled
     # pixel becomes a test sample (per the comment below). Confirm intended.
     test_gt = gt  # all of sample to be test sample
Code example #3
0
File: datasplit.py  Project: shangsw/HPDM-SPRN
    
    # Human-readable dataset names, falling back to the config key.
    dataset_names = [v.get('name', k) for k, v in DATASETS_CONFIG.items()]

    # Command-line interface.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        type=str,
                        default=None,
                        choices=dataset_names,
                        help="Dataset to use.")
    parser.add_argument('--folder',
                        type=str,
                        default="../Datasets/",
                        help="Folder where to store the "
                        "datasets (defaults to the current working directory).")
    parser.add_argument('--training_sample',
                        type=float,
                        default=10,
                        help="Percentage of samples to use for training (default: 10%%)")
    parser.add_argument('--sampling_mode',
                        type=str,
                        default='random',
                        help="Sampling mode"
                        " (random sampling or disjoint, default: random)")
    args = parser.parse_args()

    # Load the image and ground truth, then carve out train / val / test
    # splits (half of the training pixels are held out for validation).
    img, gt, _ = get_dataset(args.dataset, args.folder)
    train_gt, test_gt = sample_gt(gt, args.training_sample, mode=args.sampling_mode)
    train_gt, val_gt = sample_gt(train_gt, 0.5)

    # Save each split next to the dataset.
    out_dir = args.folder + '/' + args.dataset
    for split_name, split_gt in (('train', train_gt),
                                 ('val', val_gt),
                                 ('test', test_gt)):
        np.save(out_dir + '/' + split_name + '_gt.npy', split_gt)
    print("Done!")


Code example #4
0
def main(raw_args=None):
    """Run a FixMatch hyperspectral-classification experiment end to end.

    Parses command-line arguments, loads the dataset (either Nalepa's
    patch-based folds or a plain image cube), builds the labeled / unlabeled /
    validation splits, trains a HamidaEtAl 3-D CNN and evaluates it on the
    test ground truth. Progress and metrics are written to stdout and to a
    TensorBoard ``SummaryWriter``.

    Args:
        raw_args: optional list of argument strings (defaults to sys.argv).

    Returns:
        The run metrics dict produced by ``utils.metrics``.
    """
    parser = argparse.ArgumentParser(
        description="Hyperspectral image classification with FixMatch")
    parser.add_argument(
        '--patch_size',
        type=int,
        default=5,
        help='Size of patch around each pixel taken for classification')
    parser.add_argument(
        '--center_pixel',
        action='store_false',
        help=
        'use if you only want to consider the label of the center pixel of a patch'
    )
    parser.add_argument('--batch_size',
                        type=int,
                        default=10,
                        help='Size of each batch for training')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        help='number of total epochs of training to run')
    parser.add_argument('--dataset',
                        type=str,
                        default='Salinas',
                        help='Name of dataset to run, Salinas or PaviaU')
    parser.add_argument('--cuda',
                        type=int,
                        default=-1,
                        help='what CUDA device to run on, -1 defaults to cpu')
    parser.add_argument('--warmup',
                        type=float,
                        default=0,
                        help='warmup epochs')
    parser.add_argument('--save',
                        action='store_true',
                        help='use to save model weights when running')
    parser.add_argument(
        '--test_stride',
        type=int,
        default=1,
        help='length of stride when sliding patch window over image for testing'
    )
    parser.add_argument(
        '--sampling_percentage',
        type=float,
        default=0.3,
        help=
        'percentage of dataset to sample for training (labeled and unlabeled included)'
    )
    parser.add_argument(
        '--sampling_mode',
        type=str,
        default='nalepa',
        help='how to sample data, disjoint, random, nalepa, or fixed')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='initial learning rate')
    parser.add_argument('--alpha',
                        type=float,
                        default=1.0,
                        help='beta distribution range')
    parser.add_argument(
        '--class_balancing',
        action='store_false',
        help='use to balance weights according to ratio in dataset')
    parser.add_argument(
        '--checkpoint',
        type=str,
        default=None,
        help='use to load model weights from a certain directory')
    #Augmentation arguments
    parser.add_argument('--flip_augmentation',
                        action='store_true',
                        help='use to flip augmentation data for use')
    parser.add_argument('--radiation_augmentation',
                        action='store_true',
                        help='use to radiation noise data for use')
    parser.add_argument('--mixture_augmentation',
                        action='store_true',
                        help='use to mixture noise data for use')
    parser.add_argument('--pca_augmentation',
                        action='store_true',
                        help='use to pca augment data for use')
    parser.add_argument(
        '--pca_strength',
        type=float,
        default=1.0,
        help='Strength of the PCA augmentation, defaults to 1.')
    parser.add_argument('--cutout_spatial',
                        action='store_true',
                        help='use to cutout spatial for data augmentation')
    parser.add_argument('--cutout_spectral',
                        action='store_true',
                        help='use to cutout spectral for data augmentation')
    parser.add_argument(
        '--augmentation_magnitude',
        type=int,
        default=1,
        help=
        'Magnitude of augmentation (so far only for cutout). Defaults to 1, min 1 and max 10.'
    )
    parser.add_argument('--spatial_combinations',
                        action='store_true',
                        help='use to spatial combine for data augmentation')
    parser.add_argument('--spectral_mean',
                        action='store_true',
                        help='use to spectral mean for data augmentation')
    parser.add_argument(
        '--moving_average',
        action='store_true',
        help='use to spectral moving average for data augmentation')

    parser.add_argument('--results',
                        type=str,
                        default='results',
                        help='where to save results to (default results)')
    parser.add_argument('--save_dir',
                        type=str,
                        default='/saves/',
                        help='where to save models to (default /saves/)')
    parser.add_argument('--data_dir',
                        type=str,
                        default='/data/',
                        help='where to fetch data from (default /data/)')
    parser.add_argument('--load_file',
                        type=str,
                        default=None,
                        help='which file to load weights from (default None)')
    parser.add_argument(
        '--fold',
        type=int,
        default=0,
        help='Which fold to sample from if using Nalepas validation scheme')
    parser.add_argument(
        '--sampling_fixed',
        type=str,
        default='True',
        help=
        'Use to sample a fixed amount of samples for each class from Nalepa sampling'
    )
    parser.add_argument(
        '--samples_per_class',
        type=int,
        default=10,
        help=
        'Amount of samples to sample for each class when sampling a fixed amount. Defaults to 10.'
    )

    parser.add_argument(
        '--supervision',
        type=str,
        default='full',
        help=
        'check this more, use to make us of all labeled or not, full or semi')

    args = parser.parse_args(raw_args)

    device = utils.get_device(args.cuda)
    args.device = device

    #vis = visdom.Visdom()
    vis = None

    # One TensorBoard run directory per invocation, keyed by timestamp.
    tensorboard_dir = str(args.results + '/' +
                          datetime.datetime.now().strftime("%m-%d-%X"))

    os.makedirs(tensorboard_dir, exist_ok=True)
    writer = SummaryWriter(tensorboard_dir)

    # Load the data: Nalepa's scheme ships pre-split train/test patch folds,
    # every other mode loads one image cube and splits it below.
    if args.sampling_mode == 'nalepa':
        train_img, train_gt, test_img, test_gt, label_values, ignored_labels, rgb_bands, palette = get_patch_data(
            args.dataset,
            args.patch_size,
            target_folder=args.data_dir,
            fold=args.fold)
        args.n_bands = train_img.shape[-1]
    else:
        img, gt, label_values, ignored_labels, rgb_bands, palette = get_dataset(
            args.dataset, target_folder=args.data_dir)
        args.n_bands = img.shape[-1]

    args.n_classes = len(label_values) - len(ignored_labels)
    args.ignored_labels = ignored_labels

    if palette is None:
        # Generate color palette (class 0 stays black).
        palette = {0: (0, 0, 0)}
        for k, color in enumerate(
                sns.color_palette("hls",
                                  len(label_values) - 1)):
            palette[k + 1] = tuple(
                np.asarray(255 * np.array(color), dtype='uint8'))
    invert_palette = {v: k for k, v in palette.items()}

    def convert_to_color(x):
        return utils.convert_to_color_(x, palette=palette)

    def convert_from_color(x):
        return utils.convert_from_color_(x, palette=invert_palette)

    if args.sampling_mode == 'nalepa':
        print("{} samples selected (over {})".format(
            np.count_nonzero(train_gt),
            np.count_nonzero(train_gt) + np.count_nonzero(test_gt)))
        # Keep the TensorBoard text consistent with the console message
        # (the original reported only the test count as the total here).
        writer.add_text(
            'Amount of training samples',
            "{} samples selected (over {})".format(
                np.count_nonzero(train_gt),
                np.count_nonzero(train_gt) + np.count_nonzero(test_gt)))

        utils.display_predictions(convert_to_color(test_gt),
                                  vis,
                                  writer=writer,
                                  caption="Test ground truth")
    else:
        train_gt, test_gt = utils.sample_gt(gt,
                                            args.sampling_percentage,
                                            mode=args.sampling_mode)
        print("{} samples selected (over {})".format(
            np.count_nonzero(train_gt), np.count_nonzero(gt)))
        writer.add_text(
            'Amount of training samples',
            "{} samples selected (over {})".format(np.count_nonzero(train_gt),
                                                   np.count_nonzero(gt)))

        utils.display_predictions(convert_to_color(train_gt),
                                  vis,
                                  writer=writer,
                                  caption="Train ground truth")
        utils.display_predictions(convert_to_color(test_gt),
                                  vis,
                                  writer=writer,
                                  caption="Test ground truth")

    model = HamidaEtAl(args.n_bands,
                       args.n_classes,
                       patch_size=args.patch_size)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          nesterov=True,
                          weight_decay=0.0005)
    #loss_labeled = nn.CrossEntropyLoss(weight=weights)
    #loss_unlabeled = nn.CrossEntropyLoss(weight=weights, reduction='none')

    if args.sampling_mode == 'nalepa':
        #Get fixed amount of random samples for validation
        idx_sup, idx_val, idx_unsup = get_pixel_idx(train_img, train_gt,
                                                    args.ignored_labels,
                                                    args.patch_size)

        if args.sampling_fixed == 'True':
            # Keep at most `samples_per_class` labeled samples per class; the
            # remaining labeled pixels are demoted to the unlabeled pool.
            # BUGFIX: the original called `np.delete(idx_sup, index)` without
            # assigning the result -- a no-op -- so every labeled sample also
            # ended up duplicated in the unlabeled pool.
            unique_labels = np.zeros(len(label_values))
            new_idx_sup = []
            leftover_idx = []
            for p, x, y in idx_sup:
                label = train_gt[p, x, y]
                if unique_labels[label] < args.samples_per_class:
                    unique_labels[label] += 1
                    new_idx_sup.append([p, x, y])
                else:
                    leftover_idx.append([p, x, y])
            if leftover_idx:
                idx_unsup = np.concatenate(
                    (np.asarray(leftover_idx), idx_unsup))
            idx_sup = np.asarray(new_idx_sup)

        writer.add_text(
            'Amount of labeled training samples',
            "{} samples selected (over {})".format(idx_sup.shape[0],
                                                   np.count_nonzero(train_gt)))
        # BUGFIX: keep this as an ndarray so the `== c` comparison below is
        # element-wise; with a Python list it evaluated to a scalar False and
        # every per-class count came out as 0.
        train_labeled_gt = np.asarray(
            [train_gt[p_l, x_l, y_l] for p_l, x_l, y_l in idx_sup])

        samples_class = np.zeros(args.n_classes)
        for c in np.unique(train_labeled_gt):
            samples_class[c - 1] = np.count_nonzero(train_labeled_gt == c)
        writer.add_text('Labeled samples per class', str(samples_class))
        print('Labeled samples per class: ' + str(samples_class))

        val_dataset = HyperX_patches(train_img,
                                     train_gt,
                                     idx_val,
                                     labeled='Val',
                                     **vars(args))
        val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size)

        train_dataset = HyperX_patches(train_img,
                                       train_gt,
                                       idx_sup,
                                       labeled=True,
                                       **vars(args))
        train_loader = data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            #pin_memory=True, num_workers=5,
            shuffle=True,
            drop_last=True)

        amount_labeled = idx_sup.shape[0]
    else:
        # Hold out 5% of the training pixels for validation.
        train_labeled_gt, val_gt = utils.sample_gt(train_gt,
                                                   0.95,
                                                   mode=args.sampling_mode)

        val_dataset = HyperX(img, val_gt, labeled='Val', **vars(args))
        val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size)

        writer.add_text(
            'Amount of labeled training samples',
            "{} samples selected (over {})".format(
                np.count_nonzero(train_labeled_gt),
                np.count_nonzero(train_gt)))
        samples_class = np.zeros(args.n_classes)
        for c in np.unique(train_labeled_gt):
            samples_class[c - 1] = np.count_nonzero(train_labeled_gt == c)
        writer.add_text('Labeled samples per class', str(samples_class))

        train_dataset = HyperX(img,
                               train_labeled_gt,
                               labeled=True,
                               **vars(args))
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       pin_memory=True,
                                       num_workers=5,
                                       shuffle=True,
                                       drop_last=True)

        utils.display_predictions(convert_to_color(train_labeled_gt),
                                  vis,
                                  writer=writer,
                                  caption="Labeled train ground truth")
        utils.display_predictions(convert_to_color(val_gt),
                                  vis,
                                  writer=writer,
                                  caption="Validation ground truth")

        amount_labeled = np.count_nonzero(train_labeled_gt)

    # Cosine LR schedule over the full training budget.
    args.iterations = amount_labeled // args.batch_size
    args.total_steps = args.iterations * args.epochs
    args.scheduler = get_cosine_schedule_with_warmup(
        optimizer, args.warmup * args.iterations, args.total_steps)

    if args.class_balancing:
        weights_balance = utils.compute_imf_weights(train_gt,
                                                    len(label_values),
                                                    args.ignored_labels)
        # Drop the weight of class 0 (the ignored/background class).
        args.weights = torch.from_numpy(weights_balance[1:])
        args.weights = args.weights.to(torch.float32)
    else:
        weights = torch.ones(args.n_classes)
        #weights[torch.LongTensor(args.ignored_labels)] = 0
        args.weights = weights

    args.weights = args.weights.to(args.device)
    criterion = nn.CrossEntropyLoss(weight=args.weights)
    loss_val = nn.CrossEntropyLoss(weight=args.weights)

    print(args)
    print("Network :")
    writer.add_text('Arguments', str(args))
    with torch.no_grad():
        # Grab a single batch for the (currently disabled) model summary /
        # graph export below; `inputs` avoids shadowing the builtin `input`.
        for inputs, _ in train_loader:
            break
        #summary(model.to(device), inputs.size()[1:])
        #writer.add_graph(model.to(device), inputs)
        # We would like to use device=hyperparams['device'] altough we have
        # to wait for torchsummary to be fixed first.

    if args.load_file is not None:
        model.load_state_dict(torch.load(args.load_file))
    model.zero_grad()

    try:
        train(model,
              optimizer,
              criterion,
              loss_val,
              train_loader,
              writer,
              args,
              val_loader=val_loader,
              display=vis)
    except KeyboardInterrupt:
        # Allow the user to stop the training
        pass

    # Evaluate on the held-out image and derive hard predictions.
    if args.sampling_mode == 'nalepa':
        probabilities = test(model, test_img, args)
    else:
        probabilities = test(model, img, args)
    prediction = np.argmax(probabilities, axis=-1)

    run_results = utils.metrics(prediction,
                                test_gt,
                                ignored_labels=args.ignored_labels,
                                n_classes=args.n_classes)

    # Shift predictions to 1-based class ids and blank out ignored pixels
    # before rendering.
    mask = np.zeros(test_gt.shape, dtype='bool')
    for l in args.ignored_labels:
        mask[test_gt == l] = True
    prediction += 1
    prediction[mask] = 0

    color_prediction = convert_to_color(prediction)
    utils.display_predictions(color_prediction,
                              vis,
                              gt=convert_to_color(test_gt),
                              writer=writer,
                              caption="Prediction vs. test ground truth")

    utils.show_results(run_results,
                       vis,
                       writer=writer,
                       label_values=label_values)

    writer.close()

    return run_results
Code example #5
0
# run the experiment several times
for run in range(N_RUNS):
    # Load the train/test ground truths given on the command line; the files
    # are expected to contain 'TRLabel' / 'TSLabel' arrays.
    if TRAIN_GT is not None and TEST_GT is not None:
        train_gt = open_file(TRAIN_GT)['TRLabel']
        test_gt = open_file(TEST_GT)['TSLabel']
    elif TRAIN_GT is not None:
        train_gt = open_file(TRAIN_GT)
        test_gt = np.copy(gt)
        w, h = test_gt.shape
        # Exclude the training pixels from the test ground truth.
        test_gt[(train_gt > 0)[:w, :h]] = 0
    elif TEST_GT is not None:
        # NOTE(review): train_gt is not assigned in this branch -- verify it
        # is defined before use.
        test_gt = open_file(TEST_GT)
    else:
        # Sample random training spectra
        train_gt, test_gt = sample_gt(gt,
                                      SAMPLE_PERCENTAGE,
                                      mode=SAMPLING_MODE)
    print("{} samples selected (over {})".format(np.count_nonzero(train_gt),
                                                 np.count_nonzero(gt)))
    print(
        "Running an experiment with the {} model".format(MODEL),
        "run {}/{}".format(run + 1, N_RUNS),
    )

    # Show both ground truths in the visualization backend.
    display_predictions(convert_to_color(train_gt),
                        viz,
                        caption="Train ground truth")
    display_predictions(convert_to_color(test_gt),
                        viz,
                        caption="Test ground truth")
    # delete
Code example #6
0
# Run the experiment several times, building the train/test split per run.
for run in range(N_RUNS):
    if TRAIN_GT is not None and TEST_GT is not None:
        print("Using existing train/test split...")
        # The files are expected to hold 'train_gt' / 'test_gt' arrays.
        train_gt = open_file(TRAIN_GT)['train_gt']
        test_gt = open_file(TEST_GT)['test_gt']
    elif TRAIN_GT is not None:
        train_gt = open_file(TRAIN_GT)
        test_gt = np.copy(gt)
        w, h = test_gt.shape
        # Exclude the training pixels from the test ground truth.
        test_gt[(train_gt > 0)[:w, :h]] = 0
    elif TEST_GT is not None:
        # NOTE(review): train_gt is not assigned in this branch -- verify.
        test_gt = open_file(TEST_GT)
    else:
        # Sample random training spectra
        train_gt, test_gt = sample_gt(gt,
                                      SAMPLE_PERCENTAGE,
                                      mode=SAMPLING_MODE)
        # Persist the sampled test split for later reuse.
        scipy.io.savemat("test.mat", {'test_gt': test_gt})

    print("{} samples selected (over {})".format(np.count_nonzero(train_gt),
                                                 np.count_nonzero(gt)))
    print(
        "Running an experiment with the {} model".format(MODEL),
        "run {}/{}".format(run + 1, N_RUNS),
    )

    display_predictions(convert_to_color(train_gt),
                        viz,
                        caption="Train ground truth")
    display_predictions(convert_to_color(test_gt),
                        viz,
Code example #7
0
        clf = sklearn.neighbors.KNeighborsClassifier(weights='distance')
        clf = sklearn.model_selection.GridSearchCV(
            clf, {'n_neighbors': [1, 3, 5, 10, 20]}, verbose=5, n_jobs=4)
        clf.fit(X_train, y_train)
        clf.fit(X_train, y_train)
        save_model(clf, MODEL, DATASET)
        prediction = clf.predict(img.reshape(-1, N_BANDS))
        prediction = prediction.reshape(img.shape[:2])
    else:
        # Neural network
        model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams)
        if CLASS_BALANCING:
            weights = compute_imf_weights(train_gt, N_CLASSES, IGNORED_LABELS)
            hyperparams['weights'] = torch.from_numpy(weights)
        # Split train set in train/val
        train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random')
        # Generate the dataset
        train_dataset = HyperX(img, train_gt, **hyperparams)  # 生成用于训练的数据方块和标签
        train_loader = data.DataLoader(
            train_dataset,
            batch_size=hyperparams['batch_size'],
            #pin_memory=hyperparams['device'],
            shuffle=True)
        val_dataset = HyperX(img, val_gt, **hyperparams)  # 生成用于测试的数据方块和标签
        val_loader = data.DataLoader(
            val_dataset,
            #pin_memory=hyperparams['device'],
            batch_size=hyperparams['batch_size'])

        print(hyperparams)
        print("Network :")
Code example #8
0
def main(params):
    RUNS = 4
    MX_ITER = 1000000000
    os.environ["CUDA_VISIBLE_DEVICES"] = params.GPU

    device = torch.device("cuda: 0")
    print(torch.cuda.device_count())

    new_path = str(params.DATASET) + '/Best/' + '_'.join([
        str(params.SAMPLE_PERCENTAGE),
        str(params.DHCN_LAYERS),
        str(params.CONV_SIZE),
        str(params.ROT),
        str(params.MIRROR),
        str(params.H_MIRROR)
    ]) + '/'
    if os.path.exists(new_path):
        RUNS = 4 - len(os.listdir(new_path))

    if RUNS == 0:
        return

    for _ in range(RUNS):

        start_time = time.time()

        SAMPLE_PERCENTAGE = params.SAMPLE_PERCENTAGE
        DATASET = params.DATASET
        DHCN_LAYERS = params.DHCN_LAYERS
        CONV_SIZE = params.CONV_SIZE
        H_MIRROR = params.H_MIRROR
        LR = 1e-4

        save_path = str(DATASET) + '/tmp' + str(params.GPU) + '_abc/'
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        img, gt, LABEL_VALUES, IGNORED_LABELS, _, _ = get_dataset(
            DATASET, "Datasets/")
        X, Y = get_originate_dataset(DATASET, "Datasets/")

        img = img[:, :, list(range(0, 102, 3))]

        N_CLASSES = len(LABEL_VALUES)
        INPUT_SIZE = np.shape(img)[-1]
        # train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode='fixed')
        train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode='fixed')
        #train_gt = sample_gt2(X, Y, train_gt, test_gt, SAMPLE_PERCENTAGE)

        pseudo_labelpath = 'PaviaU/pseudo_labels/pseudo_labels3/pseudo_labels3.npy'
        if not os.path.exists(pseudo_labelpath):
            newdir = 'PaviaU/pseudo_labels/pseudo_labels3/'
            if not os.path.exists(newdir):
                os.makedirs(newdir)
            _, pseudo_labels3 = sample_gt3(X, Y, train_gt, test_gt,
                                           SAMPLE_PERCENTAGE)
            np.save(pseudo_labelpath, pseudo_labels3)
        else:
            pseudo_labels3 = np.load(pseudo_labelpath)

        trainnum = np.sum(train_gt > 0)
        print("trainnum:%d" % (trainnum))
        INPUT_DATA = pre_data(img, train_gt, params, gt, pseudo_labels3)

        model_DHCN = DHCN(input_size=INPUT_SIZE,
                          embed_size=INPUT_SIZE,
                          densenet_layer=DHCN_LAYERS,
                          output_size=N_CLASSES,
                          conv_size=CONV_SIZE,
                          batch_norm=False).to(device)
        optimizer_DHCN = torch.optim.Adam(model_DHCN.parameters(),
                                          lr=LR,
                                          betas=(0.5, 0.999),
                                          weight_decay=1e-4)
        model_DHCN = nn.DataParallel(model_DHCN)
        loss_ce = nn.CrossEntropyLoss().to(device)
        loss_bce = nn.BCELoss().to(device)

        best_ACC, tmp_epoch, tmp_count, tmp_rate, recode_reload, reload_model = 0.0, 0, 0, LR, {}, False
        max_tmp_count = 300

        for epoch in range(MX_ITER):
            if epoch % 100 == 0:
                print("epoch: %d" % (epoch))
            current_time = time.time()
            if current_time - start_time > 28800:

                print('Training end')
                old_path = save_path + 'save_' + str(tmp_epoch) + '_' + str(
                    round(best_ACC, 2)) + '.pth'
                new_path = str(DATASET) + '/Best/' + '_'.join([
                    str(params.SAMPLE_PERCENTAGE),
                    str(params.DHCN_LAYERS),
                    str(params.CONV_SIZE),
                    str(params.ROT),
                    str(params.MIRROR),
                    str(params.H_MIRROR)
                ]) + '/'
                if not os.path.exists(new_path):
                    os.makedirs(new_path)
                shutil.move(old_path, new_path)
                new_name = '_'.join([
                    str(SAMPLE_PERCENTAGE),
                    str(DHCN_LAYERS),
                    str(CONV_SIZE),
                    str(params.ROT),
                    str(params.MIRROR),
                    str(params.H_MIRROR),
                    str(round(best_ACC, 2)) + '.pth'
                ])
                os.rename(
                    new_path + 'save_' + str(tmp_epoch) + '_' +
                    str(round(best_ACC, 2)) + '.pth', new_path + new_name)

                shutil.rmtree(save_path)

                break

            if reload_model == True:

                if str(tmp_epoch) in recode_reload:

                    recode_reload[str(tmp_epoch)] += 1
                    tmp_rate = tmp_rate * 0.1
                    if tmp_rate < 1e-6:

                        print('Training end')
                        old_path = save_path + 'save_' + str(
                            tmp_epoch) + '_' + str(round(best_ACC, 2)) + '.pth'
                        new_path = str(DATASET) + '/Best/' + '_'.join([
                            str(params.SAMPLE_PERCENTAGE),
                            str(params.DHCN_LAYERS),
                            str(params.CONV_SIZE),
                            str(params.ROT),
                            str(params.MIRROR),
                            str(params.H_MIRROR)
                        ]) + '/'
                        if not os.path.exists(new_path):
                            os.makedirs(new_path)
                        shutil.move(old_path, new_path)
                        new_name = '_'.join([
                            str(SAMPLE_PERCENTAGE),
                            str(DHCN_LAYERS),
                            str(CONV_SIZE),
                            str(params.ROT),
                            str(params.MIRROR),
                            str(params.H_MIRROR),
                            str(round(best_ACC, 2)) + '.pth'
                        ])
                        os.rename(
                            new_path + 'save_' + str(tmp_epoch) + '_' +
                            str(round(best_ACC, 2)) + '.pth',
                            new_path + new_name)

                        shutil.rmtree(save_path)

                        break

                    print('learning decay: ', str(tmp_epoch), tmp_rate)
                    for param_group in optimizer_DHCN.param_groups:
                        param_group['lr'] = param_group['lr'] * 0.1

                else:

                    recode_reload[str(tmp_epoch)] = 1
                    print('learning keep: ', tmp_epoch)

                pretrained_model = save_path + 'save_' + str(
                    tmp_epoch) + '_' + str(round(best_ACC, 2)) + '.pth'
                pretrain = torch.load(pretrained_model)
                model_DHCN.load_state_dict(pretrain['state_dict_DHCN'])
                reload_model = False

            model_DHCN.train()

            loss_supervised, loss_self, loss_distill, loss_distill2 = 0.0, 0.0, 0.0, 0.0

            for TRAIN_IMG, TRAIN_Y, TRAIN_PL in zip(INPUT_DATA[0],
                                                    INPUT_DATA[1],
                                                    INPUT_DATA[2]):
                scores, _ = model_DHCN(TRAIN_IMG.to(device))
                for k_Layer in range(DHCN_LAYERS + 1):
                    for i_num, (k_scores, k_TRAIN_Y) in enumerate(
                            zip(scores[k_Layer], TRAIN_Y)):
                        k_TRAIN_Y = k_TRAIN_Y.to(device)
                        loss_supervised += loss_ce(
                            k_scores.permute(1, 2, 0)[k_TRAIN_Y > 0],
                            k_TRAIN_Y[k_TRAIN_Y > 0])
                        for id_layer, k_TRAIN_PL in enumerate(TRAIN_PL):
                            k_TRAIN_PL = k_TRAIN_PL.to(device)
                            # NOTE(review): with the distillation loss below
                            # commented out, this branch is a no-op — the bare
                            # `i_num` expression exists only to keep the `if`
                            # body syntactically non-empty.
                            if (k_TRAIN_PL[i_num].sum(-1) > 1).sum() > 0:
                                i_num
                                #loss_distill += (1 / float(id_layer + 1)) * loss_bce(k_scores.permute(1,2,0).sigmoid()[k_TRAIN_PL[i_num].sum(-1) > 0], k_TRAIN_PL[i_num][k_TRAIN_PL[i_num].sum(-1) > 0])
                            else:
                                onehot2label = torch.topk(
                                    k_TRAIN_PL[i_num], k=1,
                                    dim=-1)[1].squeeze(-1)
                                loss_self += (
                                    1 / float(id_layer + 1)) * loss_ce(
                                        k_scores.permute(1, 2,
                                                         0)[onehot2label > 0],
                                        onehot2label[onehot2label > 0])
                                #loss_distill2 += (1 / float(id_layer + 1)) * loss_bce(k_scores.permute(1, 2, 0).sigmoid()[k_TRAIN_PL[i_num].sum(-1) > 0],k_TRAIN_PL[i_num][k_TRAIN_PL[i_num].sum(-1) > 0])
            loss = loss_supervised + loss_self
            #loss = loss_supervised

            # NOTE(review): clip_grad_norm_ runs BEFORE loss.backward(), so it
            # operates on the just-zeroed gradients and the freshly computed
            # gradients are never clipped. The intended order is
            # zero_grad -> backward -> clip -> step; confirm and reorder.
            optimizer_DHCN.zero_grad()
            nn.utils.clip_grad_norm_(model_DHCN.parameters(), 3.0)
            loss.backward()
            optimizer_DHCN.step()
            internum = 50
            if epoch < 300:
                internum = 100
            if epoch > 500:
                internum = 10

            if epoch % internum == 0:
                model_DHCN.eval()

                p_idx = []
                fusion_prediction = 0.0

                for k_data, current_data in enumerate(INPUT_DATA[0]):
                    scores, _ = model_DHCN(current_data.to(device))
                    if params.ROT == False:
                        for k_score in scores:
                            fusion_prediction += F.softmax(
                                k_score[0].permute(1, 2, 0),
                                dim=-1).cpu().data.numpy()
                    else:
                        for k_score in scores:
                            if k_data == 0:
                                fusion_prediction += F.softmax(
                                    k_score[0].permute(1, 2, 0),
                                    dim=-1).cpu().data.numpy()
                                fusion_prediction += np.rot90(
                                    F.softmax(k_score[1].permute(1, 2, 0),
                                              dim=-1).cpu().data.numpy(),
                                    k=2,
                                    axes=(0, 1))
                                fusion_prediction += F.softmax(
                                    k_score[2].permute(1, 2, 0),
                                    dim=-1).cpu().data.numpy()[::-1, :, :]
                                fusion_prediction += np.rot90(
                                    F.softmax(k_score[3].permute(1, 2, 0),
                                              dim=-1).cpu().data.numpy(),
                                    k=2,
                                    axes=(0, 1))[::-1, :, :]

                                p_idx.append(
                                    k_score[0].max(0)[-1].cpu().data.numpy())
                                p_idx.append(
                                    np.rot90(k_score[1].max(0)
                                             [-1].cpu().data.numpy(),
                                             k=2,
                                             axes=(0, 1)))
                                p_idx.append(k_score[2].max(0)
                                             [-1].cpu().data.numpy()[::-1, :])
                                p_idx.append(
                                    np.rot90(k_score[3].max(0)
                                             [-1].cpu().data.numpy(),
                                             k=2,
                                             axes=(0, 1))[::-1, :])

                            if k_data == 1:
                                fusion_prediction += np.rot90(
                                    F.softmax(k_score[0].permute(1, 2, 0),
                                              dim=-1).cpu().data.numpy(),
                                    k=-1,
                                    axes=(0, 1))
                                fusion_prediction += np.rot90(
                                    F.softmax(k_score[1].permute(1, 2, 0),
                                              dim=-1).cpu().data.numpy(),
                                    k=1,
                                    axes=(0, 1))
                                fusion_prediction += np.rot90(
                                    F.softmax(k_score[2].permute(1, 2, 0),
                                              dim=-1).cpu().data.numpy(),
                                    k=-1,
                                    axes=(0, 1))[::-1, :, :]
                                fusion_prediction += np.rot90(
                                    F.softmax(k_score[3].permute(1, 2, 0),
                                              dim=-1).cpu().data.numpy(),
                                    k=1,
                                    axes=(0, 1))[::-1, :, :]

                                p_idx.append(
                                    np.rot90(k_score[0].max(0)
                                             [-1].cpu().data.numpy(),
                                             k=-1,
                                             axes=(0, 1)))
                                p_idx.append(
                                    np.rot90(k_score[1].max(0)
                                             [-1].cpu().data.numpy(),
                                             k=1,
                                             axes=(0, 1)))
                                p_idx.append(
                                    np.rot90(k_score[2].max(0)
                                             [-1].cpu().data.numpy(),
                                             k=-1,
                                             axes=(0, 1))[::-1, :])
                                p_idx.append(
                                    np.rot90(k_score[3].max(0)
                                             [-1].cpu().data.numpy(),
                                             k=1,
                                             axes=(0, 1))[::-1, :])

                Acc = np.zeros([len(p_idx) + 1])
                for count, k_idx in enumerate(p_idx):
                    Acc[count] = \
                    metrics(k_idx.reshape(img.shape[:2]), test_gt, ignored_labels=IGNORED_LABELS, n_classes=N_CLASSES)[
                        'Accuracy']
                Acc[-1] = \
                metrics(fusion_prediction.argmax(-1).reshape(img.shape[:2]), test_gt, ignored_labels=IGNORED_LABELS,
                        n_classes=N_CLASSES)['Accuracy']
                OA, AA = calprecision(fusion_prediction.argmax(-1).reshape(
                    img.shape[:2]),
                                      test_gt,
                                      n_classes=N_CLASSES)
                kappa = metrics(fusion_prediction.argmax(-1).reshape(
                    img.shape[:2]),
                                test_gt,
                                ignored_labels=IGNORED_LABELS,
                                n_classes=N_CLASSES)['Kappa']

                tmp_count += 1

                if max(Acc) > best_ACC:
                    best_ACC = max(Acc)
                    save_file_path = save_path + 'save_' + str(
                        epoch) + '_' + str(round(best_ACC, 2)) + '.pth'
                    states = {
                        'state_dict_DHCN': model_DHCN.state_dict(),
                        'train_gt': train_gt,
                        'test_gt': test_gt,
                    }

                    torch.save(states, save_file_path)

                    tmp_count = 0
                    tmp_epoch = epoch
                    print('save: ', epoch, str(round(best_ACC, 2)))
                    print('save: %d, OA: %f AA: %f Kappa: %f' %
                          (epoch, OA, AA, kappa))
                    #print(loss_supervised.data, loss_self.data, loss_distill.data)
                    print(loss_supervised.data)
                    print(np.round(Acc, 2))

                if tmp_count == max_tmp_count:
                    reload_model = True
                    tmp_count = 0
コード例 #9
0
# Load the dataset
img, gt, LABEL_VALUES = get_dataset(DATASET, FOLDER)
# Number of classes
N_CLASSES = len(LABEL_VALUES)
# Number of bands (last dimension of the image tensor)
N_BANDS = img.shape[-1]
# Random seeds
seeds = [7, 17, 27, 37, 47, 57, 67, 77, 87, 97]
# seeds = [7] * 10    #used for finding the best number of groups

results = []
# run the experiment several times
for run in range(N_RUNS):
    np.random.seed(seeds[run])
    # Sample random training spectra
    train_gt, test_gt = sample_gt(gt, PERCENTAGE, seed=seeds[run])
    # Split train set in train/val
    train_gt, val_gt = sample_gt(train_gt, 0.5, seed=seeds[run])
    print("Training samples: {}, validating samples: {}, total samples: {})".
          format(np.sum(train_gt > -1), np.sum(val_gt > -1), np.sum(gt > -1)))
    print("Running an experiment with the {} model".format(MODEL),
          "run {}/{}".format(run + 1, N_RUNS))

    # data
    train_dataset = HyperX(img,
                           train_gt,
                           patch_size=PATCH_SIZE,
                           data_aug=DATA_AUG)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=BATCH_SIZE,
                                   drop_last=False,
コード例 #10
0
ファイル: main.py プロジェクト: ShuoWangCS/HSI-DHL
def main(params):
    """Train a DHCN model on a hyperspectral dataset and checkpoint the best model.

    Runs the experiment up to 4 times (skipping runs whose results already
    exist under ``<DATASET>/Best/<config>/``). Each run samples a fixed
    train/test split, trains a DHCN with a supervised cross-entropy loss plus
    pseudo-label self-training and BCE distillation losses, evaluates after
    every epoch (fusing rotated/mirrored predictions when ``params.ROT``),
    and saves a checkpoint whenever the best test accuracy improves.

    Arguments:
    params -- configuration object with attributes GPU, DATASET,
              SAMPLE_PERCENTAGE, DHCN_LAYERS, CONV_SIZE, ROT, MIRROR, H_MIRROR.
    """
    RUNS = 4
    MX_ITER = 100
    os.environ["CUDA_VISIBLE_DEVICES"] = params.GPU

    device = torch.device("cuda: 0")
    # Fixed: Python 2 `print` statement -> print() call (rest of file is py3).
    print(torch.cuda.device_count())

    # Directory collecting the best checkpoints for this configuration.
    new_path = str(params.DATASET) + '/Best/' + '_'.join([
        str(params.SAMPLE_PERCENTAGE),
        str(params.DHCN_LAYERS),
        str(params.CONV_SIZE),
        str(params.ROT),
        str(params.MIRROR),
        str(params.H_MIRROR)
    ]) + '/'
    if os.path.exists(new_path):
        # Only run as many additional times as needed to reach 4 stored results.
        RUNS = 4 - len(os.listdir(new_path))

    if RUNS == 0:
        return

    for _ in range(RUNS):

        start_time = time.time()

        SAMPLE_PERCENTAGE = params.SAMPLE_PERCENTAGE
        DATASET = params.DATASET
        DHCN_LAYERS = params.DHCN_LAYERS
        CONV_SIZE = params.CONV_SIZE
        H_MIRROR = params.H_MIRROR
        LR = 1e-4

        # Scratch directory for this run's intermediate checkpoints.
        save_path = str(DATASET) + '/tmp' + str(params.GPU) + '_abc/'
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        img, gt, LABEL_VALUES, IGNORED_LABELS, _, _ = get_dataset(
            DATASET, "Datasets/")
        N_CLASSES = len(LABEL_VALUES)
        INPUT_SIZE = np.shape(img)[-1]
        train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode='fixed')
        # INPUT_DATA: [images, labels, pseudo-labels], possibly augmented
        # (rotations / mirrors) depending on params.
        INPUT_DATA = pre_data(img, train_gt, params, gt)

        model_DHCN = DHCN(input_size=INPUT_SIZE,
                          embed_size=INPUT_SIZE,
                          densenet_layer=DHCN_LAYERS,
                          output_size=N_CLASSES,
                          conv_size=CONV_SIZE,
                          batch_norm=False).to(device)
        optimizer_DHCN = torch.optim.Adam(model_DHCN.parameters(),
                                          lr=LR,
                                          betas=(0.5, 0.999),
                                          weight_decay=1e-4)
        model_DHCN = nn.DataParallel(model_DHCN)
        loss_ce = nn.CrossEntropyLoss().to(device)
        loss_bce = nn.BCELoss().to(device)

        # NOTE(review): tmp_rate, recode_reload, reload_model and max_tmp_count
        # are initialized but never used below — presumably an early-stop /
        # reload-and-decay mechanism was removed (or truncated by the source
        # this snippet was scraped from); verify against the full project.
        best_ACC, tmp_epoch, tmp_count, tmp_rate, recode_reload, reload_model = 0.0, 0, 0, LR, {}, False
        max_tmp_count = 300

        for epoch in range(MX_ITER):

            model_DHCN.train()

            loss_supervised, loss_self, loss_distill = 0.0, 0.0, 0.0

            for TRAIN_IMG, TRAIN_Y, TRAIN_PL in zip(INPUT_DATA[0],
                                                    INPUT_DATA[1],
                                                    INPUT_DATA[2]):
                scores, _ = model_DHCN(TRAIN_IMG.to(device))
                # One score map per DenseNet layer (+ the fused output).
                for k_Layer in range(DHCN_LAYERS + 1):
                    for i_num, (k_scores, k_TRAIN_Y) in enumerate(
                            zip(scores[k_Layer], TRAIN_Y)):
                        k_TRAIN_Y = k_TRAIN_Y.to(device)
                        # Supervised CE on the labeled (gt > 0) pixels only.
                        loss_supervised += loss_ce(
                            k_scores.permute(1, 2, 0)[k_TRAIN_Y > 0],
                            k_TRAIN_Y[k_TRAIN_Y > 0])
                        for id_layer, k_TRAIN_PL in enumerate(TRAIN_PL):
                            k_TRAIN_PL = k_TRAIN_PL.to(device)
                            if (k_TRAIN_PL[i_num].sum(-1) > 1).sum() > 0:
                                # Multi-hot pseudo-labels: BCE distillation,
                                # down-weighted by pseudo-label depth.
                                loss_distill += (
                                    1 / float(id_layer + 1)) * loss_bce(
                                        k_scores.permute(1, 2, 0).sigmoid()[
                                            k_TRAIN_PL[i_num].sum(-1) > 0],
                                        k_TRAIN_PL[i_num][
                                            k_TRAIN_PL[i_num].sum(-1) > 0])
                            else:
                                # One-hot pseudo-labels: reduce to class ids
                                # and apply CE self-training loss.
                                onehot2label = torch.topk(
                                    k_TRAIN_PL[i_num], k=1,
                                    dim=-1)[1].squeeze(-1)
                                loss_self += (
                                    1 / float(id_layer + 1)) * loss_ce(
                                        k_scores.permute(1, 2,
                                                         0)[onehot2label > 0],
                                        onehot2label[onehot2label > 0])

            loss = loss_supervised + loss_self + loss_distill

            optimizer_DHCN.zero_grad()
            loss.backward()
            # Fixed: gradient clipping must follow backward() — the original
            # clipped before backward, i.e. it clipped the just-zeroed
            # gradients and the fresh gradients were never clipped.
            nn.utils.clip_grad_norm_(model_DHCN.parameters(), 3.0)
            optimizer_DHCN.step()

            if epoch % 1 == 0:
                model_DHCN.eval()

                p_idx = []
                fusion_prediction = 0.0

                # Fuse softmax maps over all augmented views; each rotated /
                # mirrored prediction is mapped back to the original frame
                # before accumulation.
                for k_data, current_data in enumerate(INPUT_DATA[0]):
                    scores, _ = model_DHCN(current_data.to(device))
                    if params.ROT == False:
                        for k_score in scores:
                            fusion_prediction += F.softmax(
                                k_score[0].permute(1, 2, 0),
                                dim=-1).cpu().data.numpy()
                    else:
                        for k_score in scores:
                            if k_data == 0:
                                # Views: identity, 180° rot, v-flip, 180°+flip.
                                fusion_prediction += F.softmax(
                                    k_score[0].permute(1, 2, 0),
                                    dim=-1).cpu().data.numpy()
                                fusion_prediction += np.rot90(
                                    F.softmax(k_score[1].permute(1, 2, 0),
                                              dim=-1).cpu().data.numpy(),
                                    k=2,
                                    axes=(0, 1))
                                fusion_prediction += F.softmax(
                                    k_score[2].permute(1, 2, 0),
                                    dim=-1).cpu().data.numpy()[::-1, :, :]
                                fusion_prediction += np.rot90(
                                    F.softmax(k_score[3].permute(1, 2, 0),
                                              dim=-1).cpu().data.numpy(),
                                    k=2,
                                    axes=(0, 1))[::-1, :, :]

                                p_idx.append(
                                    k_score[0].max(0)[-1].cpu().data.numpy())
                                p_idx.append(
                                    np.rot90(k_score[1].max(0)
                                             [-1].cpu().data.numpy(),
                                             k=2,
                                             axes=(0, 1)))
                                p_idx.append(k_score[2].max(0)
                                             [-1].cpu().data.numpy()[::-1, :])
                                p_idx.append(
                                    np.rot90(k_score[3].max(0)
                                             [-1].cpu().data.numpy(),
                                             k=2,
                                             axes=(0, 1))[::-1, :])

                            if k_data == 1:
                                # Views: ±90° rotations and their flips.
                                fusion_prediction += np.rot90(
                                    F.softmax(k_score[0].permute(1, 2, 0),
                                              dim=-1).cpu().data.numpy(),
                                    k=-1,
                                    axes=(0, 1))
                                fusion_prediction += np.rot90(
                                    F.softmax(k_score[1].permute(1, 2, 0),
                                              dim=-1).cpu().data.numpy(),
                                    k=1,
                                    axes=(0, 1))
                                fusion_prediction += np.rot90(
                                    F.softmax(k_score[2].permute(1, 2, 0),
                                              dim=-1).cpu().data.numpy(),
                                    k=-1,
                                    axes=(0, 1))[::-1, :, :]
                                fusion_prediction += np.rot90(
                                    F.softmax(k_score[3].permute(1, 2, 0),
                                              dim=-1).cpu().data.numpy(),
                                    k=1,
                                    axes=(0, 1))[::-1, :, :]

                                p_idx.append(
                                    np.rot90(k_score[0].max(0)
                                             [-1].cpu().data.numpy(),
                                             k=-1,
                                             axes=(0, 1)))
                                p_idx.append(
                                    np.rot90(k_score[1].max(0)
                                             [-1].cpu().data.numpy(),
                                             k=1,
                                             axes=(0, 1)))
                                p_idx.append(
                                    np.rot90(k_score[2].max(0)
                                             [-1].cpu().data.numpy(),
                                             k=-1,
                                             axes=(0, 1))[::-1, :])
                                p_idx.append(
                                    np.rot90(k_score[3].max(0)
                                             [-1].cpu().data.numpy(),
                                             k=1,
                                             axes=(0, 1))[::-1, :])

                # Per-view accuracies plus (last slot) the fused prediction.
                Acc = np.zeros([len(p_idx) + 1])
                for count, k_idx in enumerate(p_idx):
                    Acc[count] = metrics(k_idx.reshape(img.shape[:2]),
                                         test_gt,
                                         ignored_labels=IGNORED_LABELS,
                                         n_classes=N_CLASSES)['Accuracy']
                Acc[-1] = metrics(fusion_prediction.argmax(-1).reshape(
                    img.shape[:2]),
                                  test_gt,
                                  ignored_labels=IGNORED_LABELS,
                                  n_classes=N_CLASSES)['Accuracy']

                tmp_count += 1

                if max(Acc) > best_ACC:

                    best_ACC = max(Acc)
                    save_file_path = save_path + 'save_' + str(
                        epoch) + '_' + str(round(best_ACC, 2)) + '.pth'
                    states = {
                        'state_dict_DHCN': model_DHCN.state_dict(),
                        'train_gt': train_gt,
                        'test_gt': test_gt,
                    }

                    torch.save(states, save_file_path)

                    tmp_count = 0
                    tmp_epoch = epoch
                    # Fixed: Python 2 print statements -> print() calls.
                    print('save: ', epoch, str(round(best_ACC, 2)))
                    print(loss_supervised.data, loss_self.data, loss_distill.data)
                    print(np.round(Acc, 2))
コード例 #11
0
def train_model(img, gt, hyperparams):
    """
    Function for model training.
    1) Data sampling into a training, a validation and a test set.
    2) Training a chosen model.
    3) Model evaluation.

    Arguments:
    img - dataset (hyperspectral image)
    gt - ground truth (labels)
    hyperparams - parameters of training
    SVM_GRID_PARAMS - parameters for SVM (if used)
    FOLDER - a path for datasets
    DATASET - name of the used dataset 
    set_parameters: option for loading a specific training and test set
    preprocessing_parameters: parameters of preprocessing
    """
    print("img.shape: {}".format(img.shape))
    print("gt.shape: {}".format(gt.shape))

    # all images should have 113 bands
    assert(img.shape[2] == 113)

    viz = None
    results = []
    # run the experiment several times
    for run in range(hyperparams['runs']):
        #############################################################################
        # Create a training and a test set
        if hyperparams['train_gt'] is not None and hyperparams['test_gt'] is not None:
            train_gt = open_file(hyperparams['train_gt'])
            test_gt = open_file(hyperparams['test_gt'])
        elif hyperparams['train_gt'] is not None:
            train_gt = open_file(hyperparams['train_gt'])
            test_gt = np.copy(gt)
            w, h = test_gt.shape
            test_gt[(train_gt > 0)[:w, :h]] = 0
        elif hyperparams['test_gt'] is not None:
            test_gt = open_file(hyperparams['test_gt'])
        else:
            # Choose type of data sampling
            if hyperparams['sampling_mode'] == 'uniform':
                train_gt, test_gt = select_subset(gt, hyperparams['training_sample'])
                check_split_correctness(gt, train_gt, test_gt, hyperparams['n_classes'])
            elif hyperparams['sampling_mode'] == 'fixed':
                # load fixed sets from a given path
                train_gt, test_gt = get_fixed_sets(run, hyperparams['sample_path'], hyperparams['dataset'])
                check_split_correctness(gt, train_gt, test_gt, hyperparams['n_classes'], 'fixed')
            else:
                train_gt, test_gt = sample_gt(gt,
                                              hyperparams['training_sample'],
                                              mode=hyperparams['sampling_mode'])
            
        print("{} samples selected (over {})".format(np.count_nonzero(train_gt),
                                                     np.count_nonzero(gt)))
        print("Running an experiment with the {} model".format(hyperparams['model']),
              "run {}/{}".format(run + 1, hyperparams['runs']))
        #######################################################################
        # Train a model

        if hyperparams['model'] == 'SVM_grid':
            print("Running a grid search SVM")
            # Grid search SVM (linear and RBF)
            X_train, y_train = build_dataset(img, train_gt,
                                             ignored_labels=hyperparams['ignored_labels'])
            class_weight = 'balanced' if hyperparams['class_balancing'] else None
            clf = sklearn.svm.SVC(class_weight=class_weight)
            clf = sklearn.model_selection.GridSearchCV(clf,
                                                       hyperparams['svm_grid_params'],
                                                       verbose=5,
                                                       n_jobs=4)
            clf.fit(X_train, y_train)
            print("SVM best parameters : {}".format(clf.best_params_))
            prediction = clf.predict(img.reshape(-1, hyperparams['n_bands']))
            save_model(clf,
                       hyperparams['model'],
                       hyperparams['dataset'],
                       hyperparams['rdir'])
            prediction = prediction.reshape(img.shape[:2])
        elif hyperparams['model'] == 'SVM':
            X_train, y_train = build_dataset(img, train_gt,
                                             ignored_labels=hyperparams['ignored_labels'])
            class_weight = 'balanced' if hyperparams['class_balancing'] else None
            clf = sklearn.svm.SVC(class_weight=class_weight)
            clf.fit(X_train, y_train)
            save_model(clf,
                       hyperparams['model'],
                       hyperparams['dataset'],
                       hyperparams['rdir'])
            prediction = clf.predict(img.reshape(-1, hyperparams['n_bands']))
            prediction = prediction.reshape(img.shape[:2])
        elif hyperparams['model'] == 'SGD':
            X_train, y_train = build_dataset(img, train_gt,
                                             ignored_labels=hyperparams['ignored_labels'])
            X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
            scaler = sklearn.preprocessing.StandardScaler()
            X_train = scaler.fit_transform(X_train)
            class_weight = 'balanced' if hyperparams['class_balancing'] else None
            clf = sklearn.linear_model.SGDClassifier(class_weight=class_weight,
                                                     learning_rate='optimal',
                                                     tol=1e-3,
                                                     average=10)
            clf.fit(X_train, y_train)
            save_model(clf,
                       hyperparams['model'],
                       hyperparams['dataset'],
                       hyperparams['rdir'])
            prediction = clf.predict(scaler.transform(img.reshape(-1,
                                                      hyperparams['n_bands'])))
            prediction = prediction.reshape(img.shape[:2])
        elif hyperparams['model'] == 'nearest':
            X_train, y_train = build_dataset(img,
                                             train_gt,
                                             ignored_labels=hyperparams['ignored_labels'])
            X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
            class_weight = 'balanced' if hyperparams['class_balancing'] else None
            clf = sklearn.neighbors.KNeighborsClassifier(weights='distance')
            clf = sklearn.model_selection.GridSearchCV(clf,
                                                       {'n_neighbors': [1, 3, 5, 10, 20]},
                                                       verbose=5,
                                                       n_jobs=4)
            clf.fit(X_train, y_train)
            clf.fit(X_train, y_train)
            save_model(clf,
                       hyperparams['model'],
                       hyperparams['dataset'],
                       hyperparams['rdir'])
            prediction = clf.predict(img.reshape(-1, hyperparams['n_bands']))
            prediction = prediction.reshape(img.shape[:2])
        else:
            # Neural network
            model, optimizer, loss, hyperparams = get_model(hyperparams['model'], **hyperparams)
            if hyperparams['class_balancing']:
                weights = compute_imf_weights(train_gt,
                                              hyperparams['n_classes'],
                                              hyperparams['ignored_labels'])
                hyperparams['weights'] = torch.from_numpy(weights)
            # Split train set in train/val
            if hyperparams['sampling_mode'] in {'uniform', 'fixed'}:
                train_gt, val_gt = select_subset(train_gt, 0.95)
            else:
                train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random')
            # Generate the dataset
            train_dataset = HyperX(img, train_gt, **hyperparams)
            train_loader = data.DataLoader(train_dataset,
                                           batch_size=hyperparams['batch_size'],
                                           shuffle=True)
            val_dataset = HyperX(img, val_gt, **hyperparams)
            val_loader = data.DataLoader(val_dataset,
                                         batch_size=hyperparams['batch_size'])

            print(hyperparams)
            print("Network :")
            with torch.no_grad():
                for input, _ in train_loader:
                    break
                summary(model.to(hyperparams['device']), input.size()[1:])
                # We would like to use device=hyperparams['device'] altough we have
                # to wait for torchsummary to be fixed first.

            if hyperparams['checkpoint'] is not None:
                model.load_state_dict(torch.load(hyperparams['checkpoint']))

            try:
                train(model,
                      optimizer,
                      loss,
                      train_loader,
                      hyperparams['epoch'],
                      scheduler=hyperparams['scheduler'],
                      device=hyperparams['device'],
                      supervision=hyperparams['supervision'],
                      val_loader=val_loader,
                      display=viz,
                      rdir=hyperparams['rdir'],
                      model_name=hyperparams['model'],
                      preprocessing=hyperparams['preprocessing']['type'],
                      run=run)
            except KeyboardInterrupt:
                # Allow the user to stop the training
                pass

            probabilities = test(model, img, hyperparams)
            prediction = np.argmax(probabilities, axis=-1)

        #######################################################################
        # Evaluate the model
        # If test set is not empty
        if(np.unique(test_gt).shape[0] > 1):
            run_results = metrics(prediction,
                                  test_gt,
                                  ignored_labels=hyperparams['ignored_labels'],
                                  n_classes=hyperparams['n_classes'])

        mask = np.zeros(gt.shape, dtype='bool')
        for l in hyperparams['ignored_labels']:
            mask[gt == l] = True
        prediction[mask] = 0