def pretrain_lstm(epoch_nb, encoder, decoder, loader, args):
    """Pretrain a sequence autoencoder to reconstruct its input in reverse order.

    Encoder and decoder are optimised jointly with Adam on the module-level
    ``criterion1``, comparing ``decoder(encoder(data))`` with ``data`` flipped
    along dimension 1. Every 5 epochs (epochs 1, 6, 11, ... counted from 1)
    and once after the last epoch, the loss history is plotted with
    ``plotting`` and ``[encoder, decoder]`` is saved with ``torch.save``.

    Args:
        epoch_nb: number of training epochs; 0 trains nothing and writes nothing.
        encoder, decoder: torch modules; the decoder output must be comparable
            by ``criterion1`` with the flipped input.
        loader: iterable of ``(data, _, ids)`` batches supporting ``len()``.
        args: namespace with ``learning_rate``, ``batch_size``, ``stats_file``,
            ``path_results``, ``path_model`` and ``run_name``.

    Relies on module-level ``gpu``, ``criterion1``, ``adjust_learning_rate``,
    ``print_stats`` and ``plotting``.
    """
    learning_rate = args.learning_rate
    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)
    # Alternative optimiser previously tried:
    # torch.optim.SGD(..., lr=args.learning_rate, momentum=0.5)
    encoder.train()
    decoder.train()
    epoch_loss_list = []

    def checkpoint(epoch_number, epoch_loss):
        # Plot the loss history so far and save the current encoder/decoder pair.
        plotting(epoch_number, np.asarray(epoch_loss_list), args.path_results)
        torch.save([encoder, decoder],
                   args.path_model + 'ae-model_ep_' + str(epoch_number) + "_loss_" +
                   str(round(epoch_loss, 7)) + args.run_name + '.pkl')

    for epoch in range(epoch_nb):
        learning_rate = adjust_learning_rate(optimizer, epoch + 1, learning_rate)
        print("learning rate = " + str(learning_rate))

        total_loss = 0
        for batch_idx, (data, _, _ids) in enumerate(loader):
            if gpu:
                data = data.cuda()
            # The reconstruction target is the input reversed along dim 1
            # (presumably the time axis -- confirm with the loader). torch.flip
            # keeps the result on data's device; the previous index tensor was
            # always moved to CUDA, which crashed CPU-only runs.
            inverted_data = torch.flip(data, dims=[1])
            encoded_output = encoder(data)
            decoded = decoder(encoded_output)
            loss = criterion1(decoded, inverted_data)
            loss_data = loss.item()
            total_loss += loss_data
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (batch_idx + 1) % 200 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.7f}'.format(
                    (epoch + 1), (batch_idx + 1) * args.batch_size, len(loader) * args.batch_size,
                    100. * (batch_idx + 1) / len(loader), loss_data))
        epoch_loss = total_loss / len(loader)
        epoch_loss_list.append(epoch_loss)
        epoch_stats = "Pretraining Epoch {} Complete: Avg. Loss: {:.7f}".format(epoch + 1, epoch_loss)
        print_stats(args.stats_file, epoch_stats)
        if epoch % 5 == 0:
            checkpoint(epoch + 1, epoch_loss)

    # Final plot and checkpoint. Skip it when no epoch ran, or when the last
    # epoch was already checkpointed by the periodic branch (same file name).
    if epoch_loss_list and (epoch_nb - 1) % 5 != 0:
        checkpoint(epoch_nb, epoch_loss_list[-1])
示例#2
0
def pretrain(epoch_nb, encoder, decoder, loader, args, v=None, lock=None):
    """Jointly pretrain a two-branch (original + NDVI) autoencoder.

    Each epoch makes one pass over ``loader``. The optimised objective is the
    mean of the two branch reconstruction losses (module-level ``criterion1``),
    minimised with Adam. At the end of each epoch the average losses are
    logged through ``print_stats`` and ``[encoder, decoder]`` is saved with
    ``torch.save`` under ``args.path_model``.

    Args:
        epoch_nb: number of epochs to run.
        encoder: module called as ``encoder(data_or, data_ndvi)`` and returning
            ``(encoded, extra)``; ``extra`` is handed unchanged to the decoder.
        decoder: module called as ``decoder(encoded, extra)`` and returning the
            original-branch and NDVI-branch reconstructions.
        loader: iterable of ``(data_or, data_ndvi, ids)`` batches supporting ``len()``.
        args: namespace with ``learning_rate``, ``batch_size``, ``stats_file``,
            ``path_model`` and ``run_name``.
        v, lock: accepted for caller compatibility; not used here.

    Relies on module-level ``gpu``, ``criterion1`` and ``print_stats``.
    """
    trainable = list(encoder.parameters()) + list(decoder.parameters())
    optimizer = torch.optim.Adam(trainable, lr=args.learning_rate)
    for epoch in range(epoch_nb):
        encoder.train()
        decoder.train()
        sum_loss, sum_or, sum_ndvi = 0, 0, 0
        for step, (batch_or, batch_ndvi, _ids) in enumerate(loader):
            if gpu:
                batch_or = batch_or.cuda()
                batch_ndvi = batch_ndvi.cuda()
            encoded, extra = encoder(Variable(batch_or), Variable(batch_ndvi))
            rebuilt_or, rebuilt_ndvi = decoder(encoded, extra)
            loss_or = criterion1(rebuilt_or, Variable(batch_or))
            loss_ndvi = criterion1(rebuilt_ndvi, Variable(batch_ndvi))
            # Both branches weigh equally; a single backward pass covers both.
            loss = (loss_or + loss_ndvi) / 2
            value_or = loss_or.item()
            value_ndvi = loss_ndvi.item()
            value = (value_or + value_ndvi) / 2
            sum_loss += value
            sum_or += value_or
            sum_ndvi += value_ndvi
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step_number = step + 1
            if step_number % 200 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.7f}\tLoss_or: {:.7f}\tLoss_ndvi: {:.7f}'.format(
                    epoch + 1, step_number * args.batch_size, len(loader) * args.batch_size,
                    100. * step_number / len(loader), value, value_or, value_ndvi))
        n_batches = len(loader)
        mean_loss = sum_loss / n_batches
        mean_or = sum_or / n_batches
        mean_ndvi = sum_ndvi / n_batches
        summary = "Pretraining Epoch {} Complete: Avg. Loss: {:.7f}, Avg. Loss_or: {:.7f}, Avg. Loss_ndvi: {:.7f}".format(
            epoch + 1, mean_loss, mean_or, mean_ndvi)
        print_stats(args.stats_file, summary)
        checkpoint_path = (args.path_model + 'ae-model_ep_' + str(epoch + 1) + "_loss_" +
                           str(round(mean_loss, 7)) + args.run_name + '.pkl')
        torch.save([encoder, decoder], checkpoint_path)
示例#3
0
def main():
    """Train the two-branch (original + NDVI) autoencoder on a SITS and write the encoded image.

    Steps:
    - parse the CLI options and create the results/model folders;
    - load the original images: optionally drop one band, clip saturated
      values, mirror-extend the borders;
    - load the NDVI images the same way;
    - sort each stack by the date parsed from its file names;
    - normalise both stacks band by band;
    - pretrain the encoder/decoder with ``pretrain`` and encode every pixel patch;
    - stretch each feature to [0, 255] and write the result as a GeoTIFF.
    """
    gpu = on_gpu()
    print("ON GPU is " + str(gpu))
    # NOTE(review): ``gpu`` is local to main, but ``pretrain`` reads a module-level
    # ``gpu`` -- confirm the module defines one.

    def str2bool(value):
        """Parse a CLI boolean (argparse's ``type=bool`` maps every non-empty string, even "False", to True)."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "1", "yes", "y"):
            return True
        if lowered in ("false", "0", "no", "n"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)

    def layer_item(value):
        """Parse one layer-spec entry: "True"/"False" become bools (pooling option), anything else an int."""
        lowered = value.strip().lower()
        if lowered in ("true", "false"):
            return lowered == "true"
        return int(value)

    def normalize_stack(stack, centered):
        """Normalise ``stack`` (image, band, row, col) in place, band by band; return per-band [mean, std].

        If ``centered``, each band is first standardised with its mean/std over
        all images. Every band is then rescaled to [0, 1] with its min/max. The
        returned statistics describe the rescaled data.
        """
        n_bands = stack.shape[1]
        if centered:
            for band in range(n_bands):
                values = stack[:, band, :, :].flatten()
                stack[:, band] = (stack[:, band] - np.mean(values)) / np.std(values)
        for band in range(n_bands):
            values = stack[:, band, :, :].flatten()
            band_min, band_max = np.min(values), np.max(values)
            stack[:, band] = (stack[:, band] - band_min) / (band_max - band_min)
        stats = []
        for band in range(n_bands):
            values = stack[:, band, :, :].flatten()
            stats.append([np.mean(values), np.std(values)])
        return stats

    #Parameters
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('--satellite',
                        default="SPOT5",
                        type=str,
                        help="choose from SPOT5 and S2")
    parser.add_argument('--patch_size', default=9, type=int)
    parser.add_argument('--patch_size_ndvi', default=5, type=int)
    parser.add_argument('--nb_features',
                        default=10,
                        type=int,
                        help="f parameter from the article")
    parser.add_argument('--batch_size', default=150, type=int)
    parser.add_argument(
        '--bands_to_keep',
        default=4,
        type=int,
        help=
        'whether we delete swir band for spot-5 or blue for S2, defauld - all 4 bands'
    )
    parser.add_argument('--epoch_nb', default=2, type=int)
    parser.add_argument('--learning_rate', default=0.0001, type=float)
    parser.add_argument('--noise_factor',
                        default=0.25,
                        type=float,
                        help='for denoising AE, original images')
    parser.add_argument('--noise_factor_ndvi',
                        default=None,
                        type=float,
                        help='for denoising AE, NDVI branch')
    parser.add_argument(
        '--centered',
        default=True,
        type=str2bool,
        help='whether we center data with mean and std before training')
    # Layer specs are space-separated on the CLI, e.g. "--original_layers 32 32 64 64";
    # type=list used to split the string into single characters.
    parser.add_argument(
        '--original_layers',
        default=[32, 32, 64, 64],
        nargs='+',
        type=int,
        help='Nb of conv. layers to build AE')  #Default article model
    parser.add_argument(
        '--ndvi_layers',
        default=[16, 16, True],
        nargs='+',
        type=layer_item,
        help='Nb of conv. layers to build AE and pooling option'
    )  #Default article model
    args = parser.parse_args()

    start_time = time.time()
    run_name = "." + str(time.strftime("%Y-%m-%d_%H%M%S"))
    print(run_name)

    # We define all the paths
    path_results_final = os.path.expanduser('~/Desktop/Results/TS_clustering/')
    folder_results = "Double_Trivial_feat_" + str(
        args.nb_features) + "_patch_" + str(args.patch_size) + run_name
    if args.satellite == "SPOT5":
        path_datasets = os.path.expanduser(
            '~/Desktop/Datasets/Montpellier_SPOT5_Clipped_relatively_normalized_03_02_mask_vegetation_water_mode_parts_2004_no_DOS1_/'
        )
        path_datasets_ndvi = os.path.expanduser(
            '~/Desktop/Results/TS_clustering/NDVI_results/NDVI_images/')
        path_results = path_results_final + "Conv_3D/" + folder_results + "/"
    else:
        path_datasets = os.path.expanduser(
            '~/Desktop/Datasets/Montpellier_S2_Concatenated_1C_Clipped_norm_4096/'
        )
        path_datasets_ndvi = os.path.expanduser(
            '~/Desktop/Results/TS_clustering/NDVI_results/NDVI_images_S2/')
        path_results = path_results_final + "Conv_3D_S2/" + folder_results + "/"

    create_dir(path_results)
    stats_file = path_results + 'stats.txt'
    path_model = path_results + 'model' + run_name + "/"
    create_dir(path_model)

    print_stats(stats_file, str(args), print_to_console=True)
    # These values are derived, never read from the CLI (the first parse would
    # reject them), so they are attached to args directly.
    args.stats_file = stats_file
    args.path_results = path_results
    args.path_model = path_model
    args.run_name = run_name

    # Original images: open, optionally drop one band, clip saturated pixels,
    # and mirror border rows/cols so patches can be cut for edge pixels.
    saturation_level = {"S2": 4096, "SPOT5": 475}.get(args.satellite)
    list_image_extended = []
    list_image_date = []
    for image_name in os.listdir(path_datasets):
        if not image_name.endswith(".TIF") or image_name.endswith("band.TIF"):
            continue
        if args.satellite == "SPOT5":
            image_date = (re.search("_([0-9]*)_", image_name)).group(1)
        else:
            image_date = (re.search("S2_([0-9]*).", image_name)).group(1)
        image_array, H, W, geo, proj, _ = open_tiff(
            path_datasets, os.path.splitext(image_name)[0])
        if args.bands_to_keep == 3:
            if args.satellite == "SPOT5":
                image_array = np.delete(image_array, 3, axis=0)
            if args.satellite == "S2":
                image_array = np.delete(image_array, 0, axis=0)
        if saturation_level is not None:
            # Saturated values are replaced by the band's largest valid value.
            for band_values in image_array:
                band_values[band_values > saturation_level] = np.max(
                    band_values[band_values <= saturation_level])
        list_image_extended.append(extend(image_array, args.patch_size))
        list_image_date.append(image_date)
    sort_ind = np.argsort(list_image_date)  # we arrange images by date of acquisition
    list_image_extended = np.asarray(list_image_extended, dtype=float)[sort_ind]
    bands_nb = list_image_extended.shape[1]
    temporal_dim = list_image_extended.shape[0]
    list_image_date = np.asarray(list_image_date)[sort_ind]
    print(list_image_date)
    list_norm = normalize_stack(list_image_extended, args.centered)

    # NDVI images: same pipeline with a single band. H, W, geo and proj keep the
    # values of the last NDVI image opened; they are reused for the output raster.
    list_image_extended_ndvi = []
    list_image_date_ndvi = []
    for image_name_ndvi in os.listdir(path_datasets_ndvi):
        if not (image_name_ndvi.endswith(".TIF") and image_name_ndvi.startswith("NDVI_")):
            continue
        image_date_ndvi = (re.search("_([0-9]*).", image_name_ndvi)).group(1)
        image_array_ndvi, H, W, geo, proj, _ = open_tiff(
            path_datasets_ndvi, os.path.splitext(image_name_ndvi)[0])
        image_array_ndvi = np.reshape(image_array_ndvi, (1, H, W))
        list_image_extended_ndvi.append(extend(image_array_ndvi, args.patch_size_ndvi))
        list_image_date_ndvi.append(image_date_ndvi)
    sort_ind_ndvi = np.argsort(list_image_date_ndvi)  # we arrange images by date of acquisition
    list_image_extended_ndvi = np.asarray(list_image_extended_ndvi, dtype=float)[sort_ind_ndvi]
    list_image_date_ndvi = np.asarray(list_image_date_ndvi)[sort_ind_ndvi]
    print(list_image_date_ndvi)
    list_norm_ndvi = normalize_stack(list_image_extended_ndvi, args.centered)

    # We create a training dataset from our SITS (axes 0 and 1 swapped: band first)
    list_image_extended_tr = np.transpose(list_image_extended, (1, 0, 2, 3))
    list_image_extended_ndvi_tr = np.transpose(list_image_extended_ndvi, (1, 0, 2, 3))
    nbr_patches_per_image = H * W  # Nbr of training patches for the dataset
    print_stats(stats_file,
                "Nbr of training patches  " + str(nbr_patches_per_image),
                print_to_console=True)
    image = ImageDataset(
        list_image_extended_tr,
        list_image_extended_ndvi_tr, args.patch_size, args.patch_size_ndvi,
        range(nbr_patches_per_image))  #we create a dataset with tensor patches
    loader_pretrain = dsloader(image, gpu, args.batch_size, shuffle=True)
    image = None

    # We create encoder and decoder models
    if args.noise_factor is not None:
        encoder = Encoder(bands_nb, args.patch_size, args.patch_size_ndvi,
                          args.nb_features, temporal_dim, args.original_layers,
                          args.ndvi_layers, np.asarray(list_norm),
                          np.asarray(list_norm_ndvi), args.noise_factor,
                          args.noise_factor_ndvi)  # On CPU
    else:
        encoder = Encoder(bands_nb, args.patch_size, args.patch_size_ndvi,
                          args.nb_features, temporal_dim, args.original_layers,
                          args.ndvi_layers)  # On CPU
    decoder = Decoder(bands_nb, args.patch_size, args.patch_size_ndvi,
                      args.nb_features, temporal_dim, args.original_layers,
                      args.ndvi_layers)  # On CPU
    if gpu:
        encoder = encoder.cuda()  # On GPU
        decoder = decoder.cuda()  # On GPU

    print_stats(stats_file, str(encoder), print_to_console=False)

    # We pretrain the model
    pretrain(args.epoch_nb, encoder, decoder, loader_pretrain, args)
    total_time_pretraining = str(datetime.timedelta(seconds=time.time() - start_time))
    print_stats(
        args.stats_file,
        "Total time pretraining =" + str(total_time_pretraining) + "\n")

    # We pass to the encoding part
    start_time = time.time()
    # Release the training loader before building the encoding dataset
    loader_pretrain = None
    image = ImageDataset(list_image_extended_tr, list_image_extended_ndvi_tr,
                         args.patch_size, args.patch_size_ndvi,
                         range(H * W))  # we create a dataset with tensor patches
    # Smaller batch sizes are tried if creating the loader raises RuntimeError.
    try:
        batch_size = W
        loader_enc_final = dsloader(image, gpu, batch_size=batch_size, shuffle=False)
    except RuntimeError:
        try:
            batch_size = int(W / 5)
            loader_enc_final = dsloader(image, gpu, batch_size=batch_size, shuffle=False)
        except RuntimeError:
            batch_size = int(W / 20)
            loader_enc_final = dsloader(image, gpu, batch_size=batch_size, shuffle=False)
    image = None

    print_stats(stats_file, 'Encoding...')
    encoded_array = encoding(encoder, loader_enc_final, batch_size)

    # We stretch every encoded feature between 0 and 255
    for band in range(args.nb_features):
        band_min = np.min(encoded_array[:, band])
        band_max = np.max(encoded_array[:, band])
        encoded_array[:, band] = 255 * (encoded_array[:, band] - band_min) / (band_max - band_min)
    print(encoded_array.shape)

    # We write the image
    new_encoded_array = np.transpose(encoded_array, (1, 0))
    ds = create_tiff(
        encoded_array.shape[-1], args.path_results + "Encoded_3D_conv_" +
        str(encoded_array.shape[-1]) + ".TIF", W, H, gdal.GDT_Int16,
        np.reshape(new_encoded_array,
                   (encoded_array.shape[-1], H, W)), geo, proj)
    ds.GetRasterBand(1).SetNoDataValue(-9999)
    ds = None

    total_time_encoding = str(datetime.timedelta(seconds=time.time() - start_time))
    print_stats(stats_file,
                "Total time encoding =" + str(total_time_encoding) + "\n")
def main():
    """Train a patch autoencoder on every image of a SITS with early stopping, then encode each image.

    Steps:
    - parse the CLI options and create the results/model folders;
    - load and mirror-extend every ``.TIF`` image, sorted by the date parsed
      from its name;
    - standardise, then rescale, the stack band by band;
    - sample random training/validation patches from each image;
    - train encoder + decoder with MSE loss and ``EarlyStopping``;
    - encode every pixel patch of every image with ``encode_image``.
    """
    gpu = on_gpu()
    print("ON GPU is " + str(gpu))

    start_time = time.time()
    run_name = "." + str(time.strftime("%Y-%m-%d_%H%M"))
    print(run_name)

    def normalize_stack(stack):
        """Standardise, then rescale to [0, 1], each band of ``stack`` (image, band, row, col) in place.

        Returns per-band ``[mean, std]`` of the rescaled data, used by the
        encoder for its Gaussian noise.
        """
        n_bands = stack.shape[1]
        for band in range(n_bands):
            values = stack[:, band, :, :].flatten()
            stack[:, band] = (stack[:, band] - np.mean(values)) / np.std(values)
        for band in range(n_bands):
            values = stack[:, band, :, :].flatten()
            band_min, band_max = np.min(values), np.max(values)
            stack[:, band] = (stack[:, band] - band_min) / (band_max - band_min)
        stats = []
        for band in range(n_bands):
            values = stack[:, band, :, :].flatten()
            stats.append([np.mean(values), np.std(values)])
        return stats

    #Parameters
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('--patch_size', default=9, type=int)
    parser.add_argument('--nb_features', default=5, type=int)
    parser.add_argument('--batch_size', default=150, type=int)
    parser.add_argument('--bands_to_keep', default=4, type=int)
    parser.add_argument('--epoch_nb', default=4, type=int)
    parser.add_argument('--satellite', default="SPOT5", type=str)
    parser.add_argument('--learning_rate', default=0.0001, type=float)
    args = parser.parse_args()

    # path with images to encode
    path_datasets = os.path.expanduser(
        '~/Desktop/Datasets/Montpellier_SPOT5_Clipped_relatively_normalized_03_02_mask_vegetation_water_mode_parts_2004_no_DOS1_/'
    )
    # folder and path to results
    folder_results = "All_images_ep_" + str(args.epoch_nb) + "_patch_" + str(
        args.patch_size) + "_batch_" + str(args.batch_size) + "_feat_" + str(
            args.nb_features) + "_lr_" + str(
                args.learning_rate) + run_name + "_noise1"
    path_results = os.path.expanduser(
        '~/Desktop/Results/Encode_TS_noise/') + folder_results + "/"
    create_dir(path_results)
    # folder with AE models
    path_model = path_results + 'model' + run_name + "/"
    create_dir(path_model)
    # file with corresponding statistics
    stats_file = path_results + 'stats.txt'

    print_stats(stats_file, str(args), print_to_console=False)
    # Derived values, never read from the CLI (the first parse would reject them)
    args.stats_file = stats_file
    args.path_results = path_results
    args.path_model = path_model
    args.run_name = run_name

    #We open images and "extend" them (we mirror border rows and columns for correct patch extraction)
    list_image_extended = []
    list_image_date = []
    for image_name in os.listdir(path_datasets):
        if not image_name.endswith(".TIF") or image_name.endswith("band.TIF"):
            continue
        image_date = (re.search("_([0-9]*)_", image_name)).group(1)
        image_array, H, W, geo, proj, _ = open_tiff(
            path_datasets, os.path.splitext(image_name)[0])
        # we delete swir bands for spot-5 or blue for Sentinel-2 if needed
        if args.bands_to_keep == 3:
            if args.satellite == "SPOT5":
                image_array = np.delete(image_array, 3, axis=0)
            else:
                image_array = np.delete(image_array, 0, axis=0)
        list_image_extended.append(extend(image_array, args.patch_size))
        list_image_date.append(image_date)
    bands_nb = args.bands_to_keep
    sort_ind = np.argsort(list_image_date)  # we arrange images by date of acquisition
    list_image_extended = np.asarray(list_image_extended, dtype=float)[sort_ind]
    list_image_date = np.asarray(list_image_date)[sort_ind]

    # Normalise with dataset mean/std, rescale to [0, 1], keep mean/std for the noise
    list_norm = normalize_stack(list_image_extended)

    # Training and validation datasets: per image, ~2*H*W/SITS_length random
    # patches (capped at H*W so random.sample cannot exceed its population,
    # e.g. with a single image) plus 1% as many validation patches.
    nbr_images = len(list_image_extended)
    nbr_patches_per_image = min(int(H * W / nbr_images * 2), H * W)
    train_parts = []
    valid_parts = []
    for ii in range(nbr_images):
        samples_list = np.sort(sample(range(H * W), nbr_patches_per_image))
        samples_list_valid = np.sort(
            sample(range(H * W), int(nbr_patches_per_image / 100)))
        train_parts.append(ImageDataset(
            list_image_extended[ii], args.patch_size, ii,
            samples_list))  # we create a dataset with tensor patches
        valid_parts.append(ImageDataset(
            list_image_extended[ii], args.patch_size, ii,
            samples_list_valid))  # we create a dataset with tensor patches
    # One flat concatenation instead of a chain of nested ConcatDatasets
    image = torch.utils.data.ConcatDataset(train_parts)
    image_valid = torch.utils.data.ConcatDataset(valid_parts)
    nbr_train_patches = nbr_patches_per_image * nbr_images

    loader = dsloader(image, gpu, args.batch_size, shuffle=True)
    loader_valid = dsloader(image_valid, gpu, H, shuffle=False)

    # we create AE model
    encoder = Encoder(bands_nb, args.patch_size, args.nb_features,
                      np.asarray(list_norm))  # On CPU
    decoder = Decoder(bands_nb, args.patch_size, args.nb_features)  # On CPU
    if gpu:
        encoder = encoder.cuda()  # On GPU
        decoder = decoder.cuda()  # On GPU

    optimizer_encoder = torch.optim.Adam(encoder.parameters(),
                                         lr=args.learning_rate)
    optimizer_decoder = torch.optim.Adam(decoder.parameters(),
                                         lr=args.learning_rate)

    criterion = nn.MSELoss()

    with open(stats_file, 'a') as f:
        f.write(str(encoder) + "\n")

    # Here we deploy early stopping algorithm taken from https://github.com/Bjarten/early-stopping-pytorch
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    early_stopping = EarlyStopping(patience=1, verbose=True)

    def train(epoch):
        """Run one epoch (``epoch`` is 1-based): train, checkpoint, validate, update early stopping."""
        encoder.train()
        decoder.train()
        train_loss_total = 0
        for batch_idx, (data, _, _) in enumerate(loader):
            if gpu:
                data = data.cuda()
            encoded, id1 = encoder(Variable(data))
            decoded = decoder(encoded, id1)
            loss = criterion(decoded, Variable(data))
            train_loss_total += loss.item()
            optimizer_encoder.zero_grad()
            optimizer_decoder.zero_grad()
            loss.backward()
            optimizer_encoder.step()
            optimizer_decoder.step()
            if (batch_idx + 1) % 200 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    (epoch), (batch_idx + 1) * args.batch_size,
                    nbr_train_patches,
                    100. * (batch_idx + 1) / len(loader), loss.item()))
        train_loss_total = train_loss_total / len(loader)
        epoch_stats = "Epoch {} Complete: Avg. Loss: {:.7f}".format(
            epoch, train_loss_total)
        print(epoch_stats)
        with open(stats_file, 'a') as f:
            f.write(epoch_stats + "\n")

        # We save trained model after each epoch. Optional.
        # epoch is already 1-based, so it is used as-is in the file names.
        torch.save([encoder, decoder],
                   (path_model + 'ae-model_ep_' + str(epoch) + "_loss_" +
                    str(round(train_loss_total, 5)) + run_name + '.pkl'))
        torch.save(
            [encoder.state_dict(), decoder.state_dict()],
            (path_model + 'ae-dict_ep_' + str(epoch) + "_loss_" +
             str(round(train_loss_total, 5)) + run_name + '.pkl'))

        # Validation part: no gradients are needed, so no autograd graph is built
        valid_loss_total = 0
        encoder.eval()
        decoder.eval()  # prep model for evaluation
        with torch.no_grad():
            for data, _, _ in loader_valid:
                if gpu:
                    data = data.cuda()
                encoded, id1 = encoder(Variable(data))
                decoded = decoder(encoded, id1)
                valid_loss_total += criterion(decoded, Variable(data)).item()
        valid_loss_total = valid_loss_total / len(loader_valid)

        avg_train_losses.append(train_loss_total)
        avg_valid_losses.append(valid_loss_total)

        epoch_len = len(str(args.epoch_nb))
        print_msg = (f'[{epoch:>{epoch_len}}/{args.epoch_nb:>{epoch_len}}] ' +
                     f'train_loss: {train_loss_total:.5f} ' +
                     f'valid_loss: {valid_loss_total:.5f}')
        print(print_msg)

        # We plot the loss every 5th epoch
        if epoch % 5 == 0:
            plotting(epoch, avg_train_losses, path_results)

        # early_stopping needs the validation loss to check if it has decresed,
        # and if it has, it will make a checkpoint of the current model
        early_stopping(valid_loss_total, [encoder, decoder])

    for epoch in range(1, args.epoch_nb + 1):
        train(epoch)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # time.clock() was removed in Python 3.8, and it never matched time.time()'s epoch
    end_time_learning = time.time()
    total_time_learning = str(datetime.timedelta(seconds=end_time_learning - start_time))
    print_stats(args.stats_file,
                "Total time pretraining =" + str(total_time_learning) + "\n")

    # We get the best model (here by default it is the last one; already on GPU if gpu)
    best_epoch = epoch
    best_epoch_loss = avg_train_losses[best_epoch - 1]
    print("best epoch " + str(best_epoch))
    print("best epoch loss " + str(best_epoch_loss))
    best_encoder = encoder

    #ENCODING PART
    for ii in range(len(list_image_extended)):
        print("Encoding " + str(list_image_date[ii]))
        samples_list = np.array(range(H * W))
        image_encode = ImageDataset(
            list_image_extended[ii], args.patch_size, ii,
            samples_list)  # we create a dataset with tensor patches

        loader_encode = dsloader(image_encode, gpu, H, shuffle=False)

        name_results = list_image_date[ii]
        encode_image(best_encoder, loader_encode, H * 10, args.nb_features,
                     gpu, H, W, geo, proj, name_results, path_results)

    total_time_encoding = str(datetime.timedelta(seconds=time.time() - end_time_learning))
    print_stats(args.stats_file,
                "Total time encoding =" + str(total_time_encoding) + "\n")
def _remap_ground_truth(cm_truth, cm_truth3):
    """Merge Corine Land Cover codes into the custom evaluation classes.

    ``cm_truth`` holds flattened level-1 codes and is modified in place and
    returned; ``cm_truth3`` holds the matching flattened level-3 codes.
    Level-1 merges are applied first, then level-3 overrides, so a level-3
    rule always wins for the pixels it matches.
    """
    # Level-1 merges: industrial areas (2) and material extraction (3) join
    # the city class (1); green spaces (4) join class 6 (wooded areas).
    # No destination is also a source, so sequential application is safe.
    for src, dst in ((2, 1), (3, 1), (4, 6)):
        cm_truth[cm_truth == src] = dst

    # Level-3 overrides (masks depend only on cm_truth3, so order is free):
    # family gardens (511) and urban open spaces (512) -> class 6;
    # annual crops (513), meadows (514), vineyards (521), orchards (522) and
    # olive groves (523) keep their own level-3 code; airports (240) -> 0.
    lvl3_overrides = {
        511: 6, 512: 6,
        513: 513, 514: 514,
        521: 521, 522: 522, 523: 523,
        240: 0,
    }
    for code, dst in lvl3_overrides.items():
        cm_truth[cm_truth3 == code] = dst
    return cm_truth


def _repair_single_negative_label(cm_predicted):
    """Replace a lone negative cluster label with a neighbouring pixel's label.

    Modifies and returns the flat array ``cm_predicted``. Acts only when
    exactly one entry is negative: that entry takes the label of the previous
    pixel (or of the next one if it is the first pixel). When the negative
    entry is the last pixel this copies the second-to-last label.
    """
    negative = np.flatnonzero(cm_predicted < 0)
    if len(negative) == 1 and cm_predicted.size > 1:
        pos = negative[0]
        neighbour = pos - 1 if pos > 0 else pos + 1
        cm_predicted[pos] = cm_predicted[neighbour]
    return cm_predicted


def calculate_stats(folder_enc, segmentation_name, clustering_final_name, apply_mask_outliers=True, S2=False,
                    cluster_counts=range(8, 16)):
    """Score clustering maps against a custom Corine Land Cover ground truth.

    Builds the custom ground truth from the level-1 and level-3 CLC rasters,
    writes it (as consecutive labels starting at 1) to ``<lvl1>_custom`` via
    ``create_tiff``/``vectorize_tiff``, then computes NMI and ARI for every
    clustering map ``<clustering_final_name>_<mean|median>_<cl>`` and reports
    the rounded scores through ``print_stats`` to ``<path_main><folder_enc>stats.txt``.

    Args:
        folder_enc: sub-path (appended to ``path_main``) holding the results.
        segmentation_name: sub-folder of ``folder_enc`` with the segmentation.
        clustering_final_name: sub-folder / file prefix of the clustering maps.
        apply_mask_outliers: if true, only pixels whose ``Outliers_total``
            value is 1 are evaluated.
        S2: use the 2017 CLC maps instead of the 2008 ones.
        cluster_counts: cluster numbers ``cl`` to evaluate (default 8..15).
    """
    print("S2", S2)
    stats_file = path_main + folder_enc + 'stats.txt'
    path_cm = os.path.expanduser('~/Desktop/Datasets/occupation_des_sols/')

    # Corine Land Cover GT maps come in several precision levels; the custom
    # GT combines level-1 and level-3 classes.
    clc_prefix = "clc_2017" if S2 else "clc_2008"
    cm_truth_name = clc_prefix + "_lvl1"
    cm_truth_name3 = clc_prefix + "_lvl3"
    cm_truth, H, W, geo, proj, _ = open_tiff(path_cm, cm_truth_name)
    cm_truth3, _, _, _, _, _ = open_tiff(path_cm, cm_truth_name3)
    cm_truth = _remap_ground_truth(cm_truth.flatten(), cm_truth3.flatten())

    # Relabel to consecutive indices 0..K-1 ordered by original class value.
    _, cm_truth_mod = np.unique(cm_truth, return_inverse=True)
    print(np.unique(cm_truth))

    ds = create_tiff(1, path_cm + cm_truth_name + "_custom", W, H,
                     gdal.GDT_Int16,
                     np.reshape(cm_truth_mod + 1, (H, W)), geo, proj)
    vectorize_tiff(path_cm, cm_truth_name + "_custom", ds)
    ds.FlushCache()
    ds = None

    # Evaluate only pixels with a positive GT label (this drops airports, set
    # to 0 above). Selecting on the labels themselves, not on the np.unique
    # index, avoids dropping a real class when no 0 label is present.
    eval_idx = np.flatnonzero(cm_truth > 0)
    if apply_mask_outliers:
        outliers_total, _, _, _, _, _ = open_tiff(path_main, "Outliers_total")
        eval_idx = np.intersect1d(np.flatnonzero(outliers_total.flatten() == 1), eval_idx)
    # Invariant across all (descriptor, cluster count) pairs: slice once.
    cm_truth_eval = cm_truth_mod[eval_idx]

    clustering_dir = path_main + folder_enc + segmentation_name + "/" + clustering_final_name + "/"
    mask_label = " WITH MASK" if apply_mask_outliers else " WITHOUT MASK"

    for mean_or_median in ["mean", "median"]:
        print("Descriptor type " + mean_or_median)
        nmi_list = []
        ari_list = []
        print_stats(stats_file, "\n New classes", print_to_console=True)
        print_stats(stats_file, "\n " + str(segmentation_name) + "_" + str(clustering_final_name), print_to_console=True)

        for cl in cluster_counts:
            print("Clusters=" + str(cl))
            image_name_clust = clustering_final_name + "_" + mean_or_median + "_" + str(cl)
            image_array_cl, _, _, _, _, _ = open_tiff(clustering_dir, image_name_clust)
            cm_predicted = _repair_single_negative_label(image_array_cl.flatten())[eval_idx]

            nmi = normalized_mutual_info_score(cm_truth_eval, cm_predicted)
            ari = adjusted_rand_score(cm_truth_eval, cm_predicted)
            print(nmi)
            print(ari)

            nmi_list.append(np.round(nmi, 2))
            ari_list.append(np.round(ari, 2))

        print_stats(stats_file, mean_or_median + mask_label, print_to_console=True)
        print_stats(stats_file, "NMI", print_to_console=True)
        print_stats(stats_file, str(nmi_list), print_to_console=True)
        print_stats(stats_file, "ARI", print_to_console=True)
        print_stats(stats_file, str(ari_list), print_to_console=True)