def create(self):
    """(Re)build the display frame showing a mosaic of ``self.files``.

    Any existing display frame is removed. The rasters listed in
    ``self.files`` are merged with the method chosen in ``self.dropdown``,
    the mosaic is drawn as a stretched RGB image in an embedded Matplotlib
    canvas with a navigation toolbar, and the mosaic is written as a GeoTIFF
    to ``OutputImages/<self.entry text>.tif``.
    """
    import contextlib

    if self.display.winfo_exists():
        self.display.grid_forget()
        self.display.destroy()

    self.display = ttk.Frame(self)
    self.display.grid(row=0, column=4, rowspan=2, sticky='nwes')
    frame = self.display

    # ExitStack closes every opened raster once merging is done, even if an
    # error occurs part way through; the original left all datasets open.
    with contextlib.ExitStack() as stack:
        sources = [stack.enter_context(rasterio.open(path)) for path in self.files]
        mos, out = merge(sources, method=self.dropdown.get())
        out_meta = sources[0].meta.copy()

    self.figure = Figure(figsize=(6, 4), dpi=100)
    self.plot = self.figure.add_subplot(1, 1, 1)
    ep.plot_rgb(mos, stretch=True, str_clip=0.5, ax=self.plot)
    self.canvas = FigureCanvasTkAgg(self.figure, frame)
    self.toolbar = NavigationToolbar2Tk(self.canvas, frame)
    self.toolbar.update()
    self.canvas.get_tk_widget().pack()

    out_path = 'OutputImages/' + self.entry.get() + '.tif'
    # The mosaic is band-first, so rows are shape[1] and columns shape[2].
    out_meta.update({
        "driver": "GTiff",
        "height": mos.shape[1],
        "width": mos.shape[2],
        "transform": out,
        # NOTE(review): the CRS is hard-coded to UTM zone 35 rather than taken
        # from the source rasters; confirm this holds for every input.
        "crs": "+proj=utm +zone=35 +ellps=GRS80 +units=m +no_defs ",
    })
    with rasterio.open(out_path, 'w', **out_meta) as dst:
        dst.write(mos)
def go_plot(p_bbox, p_raster_data, p_raster_data_extent, family, species,
            points, p_xticks, p_jticks, p_xlabels, p_ylabels):
    """Render an RGB raster with point locations on top and save it as a PNG.

    Parameters
    ----------
    p_bbox : sequence of 4 numbers
        Axis limits in the order (xmin, xmax, ymin, ymax).
    p_raster_data : object with a ``.values`` array
        Raster whose bands 0, 1, 2 are drawn as R, G, B
        (presumably an xarray DataArray -- confirm with callers).
    p_raster_data_extent : sequence
        Extent forwarded to ``ep.plot_rgb``.
    family, species : str
        Printed, and used to build the output file name
        ``output_map + family + "_" + species + ".png"``.
    points : iterable of geometries
        Geometries tagged as EPSG:3857 and drawn in black above the raster.
    p_xticks, p_jticks : sequences
        Tick positions for the x and y axes.
    p_xlabels, p_ylabels : sequences
        Tick labels for the x and y axes.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    try:
        ep.plot_rgb(
            p_raster_data.values,
            rgb=[0, 1, 2],
            ax=ax,
            extent=p_raster_data_extent)

        p_points = geopandas.GeoDataFrame(geometry=points)
        # set_crs only labels the geometries with a CRS; it does not reproject.
        p_points.set_crs(epsg=3857, inplace=True)
        p_points.plot(ax=ax, zorder=20, color="black")

        ax.set_xticks(p_xticks)
        ax.set_yticks(p_jticks)
        ax.set_xticklabels(p_xlabels)
        ax.set_yticklabels(p_ylabels)
        ax.set_xlim(p_bbox[0], p_bbox[1])
        ax.set_ylim(p_bbox[2], p_bbox[3])
        ax.add_artist(ScaleBar(dx=1, box_alpha=0.1, location='lower left'))
        print(family)
        print(species)
        fig.savefig(output_map + family + "_" + species + ".png", dpi=300)
    finally:
        # Close only this figure (the original closed *all* open figures,
        # including ones owned by the caller) and do it even on failure.
        plt.close(fig)
Ejemplo n.º 3
0
def save_imgSR(img, img_path, mode='RGB'):
    """Save *img* as a stretched RGB composite titled "SR_RGB".

    Bands 0, 1, 2 are drawn as R, G, B with earthpy's ``plot_rgb``; the
    current figure is written to ``img_path + "_RGB_SR.png"`` and closed.
    ``mode`` is accepted but not used.
    """
    ep.plot_rgb(img, stretch=True, title="SR_RGB", rgb=[0, 1, 2])
    destination = img_path + "_RGB_SR.png"
    plt.savefig(destination)
    plt.close()
Ejemplo n.º 4
0
def test_1band(rgb_image):
    """A single 2-D band is not a valid RGB stack: plot_rgb must raise ValueError."""
    image, _unused_extent = rgb_image
    single_band = image[1]

    with pytest.raises(ValueError, match="Input needs to be 3 dimensions"):
        plot_rgb(single_band)
    plt.close()
Ejemplo n.º 5
0
def test_1band(rgb_image):
    """Test to ensure the input image has 3 bands to support rgb plotting.
    If fewer than 3 bands are provided, fail gracefully."""
    a_rgb_image, _ = rgb_image

    # ``match`` is a regular expression searched in the error message. The
    # original triple-quoted pattern embedded a newline plus indentation, so it
    # only matched a message with exactly that whitespace; ``\s+`` accepts both
    # the single-space and the line-wrapped forms of the message.
    with pytest.raises(
        ValueError,
        match=r"Input needs to be 3 dimensions and in rasterio\s+order with bands first",
    ):
        plot_rgb(a_rgb_image[1])
    plt.close()
Ejemplo n.º 6
0
def ColorComposition(bands, yr='', r=2, g=1, b=0):
    """Plot a true- or false-colour composite of *bands*.

    The composite is labelled "True" only for the band order r=2, g=1, b=0
    and "False" for any other order. Band numbers in the title are shown
    1-based. The figure is plotted but NOT saved.
    """
    label = "True" if (r, g, b) == (2, 1, 0) else "False"
    heading = "RGB image (%s). R=%s, G=%s, B=%s. Year: %s" % (label, r + 1, g + 1, b + 1, yr)
    ep.plot_rgb(bands, rgb=[r, g, b], figsize=(7, 7), title=heading, stretch=True)
Ejemplo n.º 7
0
 def showImage(self):
     """Stack the three band files, save the stack, and show it as RGB.

     The stack of ``self.file1``, ``self.file2`` and ``self.file3`` is written
     to ``OutputImages/<self.out_name text>.tiff`` (nodata=0) and drawn,
     stretched, in an embedded Matplotlib canvas inside ``self.frame``.
     """
     stack_path = 'OutputImages/' + self.out_name.get() + '.tiff'
     host = self.frame
     self.band_fnames = [self.file1, self.file2, self.file3]
     stacked, _stack_meta = es.stack(self.band_fnames, out_path=stack_path, nodata=0)

     self.figure = Figure(figsize=(10, 6), dpi=100)
     self.plot = self.figure.add_subplot(1, 1, 1)
     ep.plot_rgb(stacked, ax=self.plot, stretch=True, str_clip=0.5)

     self.canvas = FigureCanvasTkAgg(self.figure, host)
     self.toolbar = NavigationToolbar2Tk(self.canvas, host)
     self.toolbar.update()
     self.canvas.get_tk_widget().pack()
Ejemplo n.º 8
0
def test_ticks_off(rgb_image):
    """plot_rgb should hide tick marks: both tick arrays come back empty."""
    image, _ = rgb_image

    fig, axis = plot_rgb(image)
    x_ticks = axis.get_xticks()
    y_ticks = axis.get_yticks()
    assert len(x_ticks) == 0
    assert len(y_ticks) == 0
    plt.close(fig)
Ejemplo n.º 9
0
def plot_ep_plot(imgs=None, stretch=None):
    """
    Show a Landsat image and a MODIS image side by side as RGB composites.

    :param imgs: tuple of paths (landsat_path, modis_path); when falsy, the
        first pair returned by ``util.get_landsat_modis_pairs`` for the
        ``LS_MD_PAIRS`` directory (transform=True, both_modis=True) is used
    :param stretch: linear stretch bool forwarded to ``ep.plot_rgb``
    :return: None
    """
    if not imgs:
        pairs_dir = os.environ['LS_MD_PAIRS']
        imgs = util.get_landsat_modis_pairs(pairs_dir,
                                            transform=True,
                                            both_modis=True)[0]

    _fig, axes = plt.subplots(1, 2)
    # (index into imgs, band order, panel title) for each subplot, left to right
    panels = (
        (0, (3, 2, 1), "Landsat True Colour"),
        (1, (0, 3, 2), "MODIS True Colour"),
    )
    for (position, band_order, label), axis in zip(panels, axes):
        with rio.open(imgs[position]) as src:
            ep.plot_rgb(
                src.read(),
                rgb=band_order,
                figsize=(10, 10),
                str_clip=2,
                ax=axis,
                extent=None,
                title=label,
                stretch=stretch,
            )
Ejemplo n.º 10
0
def test_ax_provided(rgb_image):
    """plot_rgb should draw onto an explicitly supplied axis."""
    image, _ = rgb_image
    _, given_axis = plt.subplots()
    fig, returned_axis = plot_rgb(image, ax=given_axis)

    expected_shape = image.transpose([1, 2, 0]).shape
    drawn_shape = returned_axis.get_images()[0].get_array().shape
    assert expected_shape == drawn_shape
    plt.close(fig)
Ejemplo n.º 11
0
def test_ax_not_provided(rgb_image):
    """Without an ``ax`` argument plot_rgb must create its own figure and image."""
    image, _ = rgb_image
    fig, axis = plot_rgb(image)

    expected = image.transpose([1, 2, 0]).shape
    actual = axis.get_images()[0].get_array().shape
    assert expected == actual
    plt.close(fig)
Ejemplo n.º 12
0
def test_two_ax_provided(rgb_image):
    """plot_rgb must draw correctly onto each of two axes of one figure.

    Guards against a regression where a call to plt.show cleared the
    second plot, which the single-axis tests did not catch.
    """
    image, _ = rgb_image
    _fig, (top_axis, bottom_axis) = plt.subplots(2, 1)
    first = plot_rgb(image, ax=top_axis)
    second = plot_rgb(image, ax=bottom_axis)

    expected = image.transpose([1, 2, 0]).shape
    first_shape, second_shape = (
        drawn.get_images()[0].get_array().shape for drawn in (first, second)
    )

    assert expected == first_shape
    assert expected == second_shape
    plt.close()
Ejemplo n.º 13
0
def test_no_data_val(rgb_image):
    """Stretching an int16 image that contains a -9999 nodata value still plots."""
    image, _ = rgb_image
    image = image.astype("int16")
    nodata_mask = image == 255
    image[nodata_mask] = -9999

    axis = plot_rgb(image, stretch=True)
    assert len(axis.get_images()) == 1
    plt.close()
Ejemplo n.º 14
0
def plot_raster_pair():
    """
    Show the first Landsat/MODIS pair in a 2x2 grid, stretched and unstretched.

    The directory comes from the ``LS_MD_PAIRS`` environment variable and is
    passed to ``util.get_landsat_modis_pairs``. The first two files found in
    ``pairs[0][0]`` are plotted as the Landsat (top row) and MODIS (bottom
    row) images with band order [3, 2, 1]. The left column applies a linear
    stretch (str_clip=4); the right column does not.
    :return: None
    """
    pairs_dir = os.environ['LS_MD_PAIRS']
    pairs = util.get_landsat_modis_pairs(pairs_dir)
    fig3, ax3 = plt.subplots(2, 2)

    # The original pattern (pairs[0][0] + backslash + "*") relied on an invalid
    # string escape and only worked as a path separator on Windows;
    # os.path.join is portable. Glob once so both files come from one listing.
    # NOTE(review): glob order is filesystem-dependent; confirm that index 0 is
    # always the Landsat file and index 1 the MODIS file.
    files = glob(os.path.join(pairs[0][0], "*"))

    for row, (path, sensor) in enumerate(((files[0], "Landsat"), (files[1], "MODIS"))):
        with rio.open(path) as src:
            img = src.read()
        # Draw into the prepared grid. The original assigned plot_rgb's return
        # value into ax3[...] without passing ax=, so each call opened its own
        # figure and the 2x2 grid stayed empty.
        ep.plot_rgb(img,
                    rgb=[3, 2, 1],
                    ax=ax3[row, 0],
                    title=sensor + " RGB Image\n Linear Stretch Applied",
                    stretch=True,
                    str_clip=4)
        ep.plot_rgb(img,
                    rgb=[3, 2, 1],
                    ax=ax3[row, 1],
                    title=sensor + " RGB Image",
                    stretch=False)

    plt.show()
Ejemplo n.º 15
0
def show_affine_transform(imgs=False, stretch=True):
    """
    Plot a MODIS raster before and after its affine transform, side by side.

    :param imgs: tuple of two paths (untransformed, transformed). When falsy,
        the first pair from ``util.get_landsat_modis_pairs(..., transform=True,
        both_modis=True)`` for the ``LS_MD_PAIRS`` directory is used.
    :param stretch: linear stretch bool forwarded to ``ep.plot_rgb``
    :return: None
    """
    if not imgs:
        pairs_dir = os.environ['LS_MD_PAIRS']
        pairs = util.get_landsat_modis_pairs(pairs_dir,
                                             transform=True,
                                             both_modis=True)
        imgs = pairs[0]

    fig5, ax5 = plt.subplots(1, 2)
    # Read from ``imgs`` (not ``pairs``): the original referenced ``pairs``,
    # which is undefined -- a NameError -- whenever the caller supplied imgs.
    panels = (
        (imgs[0], "MODIS Untransformed"),
        (imgs[1], "MODIS transformed"),
    )
    for axis, (path, label) in zip(ax5, panels):
        with rio.open(path) as m_src:
            ep.plot_rgb(
                m_src.read(),
                rgb=(0, 3, 2),
                figsize=(10, 10),
                str_clip=2,
                ax=axis,
                extent=None,
                title=label,
                stretch=stretch,
            )

    plt.show()
Ejemplo n.º 16
0
def test_stretch_image(rgb_image):
    """Stretching inside plot_rgb should rescale the plotted data up to 255.

    Values above 150 are zeroed first, so a maximum of 255 in the plotted
    array can only come from the stretch.
    """
    image, _ = rgb_image
    np.place(image, image > 150, [0])

    fig, axis = plot_rgb(image, stretch=True)
    plotted = axis.get_images()[0].get_array()
    assert plotted.max() == 255
    plt.close(fig)
Ejemplo n.º 17
0
def test_masked_im(rgb_image):
    """A masked input must be drawn with an extra alpha channel (RGBA, 4 layers)."""
    image, _ = rgb_image
    masked = ma.masked_where(image > 140, image)

    fig, axis = plot_rgb(masked)
    channel_count = axis.get_images()[0].get_array().shape[2]
    assert channel_count == 4
    plt.close(fig)
Ejemplo n.º 18
0
def test_stretch_output_scaled(rgb_image):
    """Every distinct str_clip value should yield a distinct plotted-array mean.

    Ten clip values (0..9) are tried; all ten means must differ.
    """
    arr, _ = rgb_image
    clip_values = list(range(10))

    def _plotted_mean(clip):
        # Render once with this clip, measure the image mean, then close it.
        axis = plot_rgb(arr, stretch=True, str_clip=clip)
        value = axis.get_images()[0].get_array().mean()
        plt.close()
        return value

    means = [_plotted_mean(clip) for clip in clip_values]
    assert len(set(means)) == len(clip_values)
Ejemplo n.º 19
0
def test_stretch_image_nan(rgb_image):
    """With NaNs present, stretching must still span exactly 0..255."""
    image, _ = rgb_image
    np.place(image, image > 150, [0])
    image = np.where(image < 25, np.nan, image)

    axis = plot_rgb(image, stretch=True)
    plotted = axis.get_images()[0].get_array()
    high, low = plotted.max(), plotted.min()
    assert high == 255 and low == 0
    plt.close()
Ejemplo n.º 20
0
def plot_train_data(X_DICT_TRAIN, Y_DICT_TRAIN, image_number = 12):
    """Show one training image next to each of its ground-truth mask bands.

    Parameters
    ----------
    X_DICT_TRAIN, Y_DICT_TRAIN : dict
        Images and ground-truth masks keyed by a zero-padded 2-character id;
        arrays are indexed as (rows, cols, bands).
    image_number : int or str
        Id of the image to show; converted with ``str(...).zfill(2)``.
    """
    labels = ['Original Image with the 8 bands', 'Ground Truths: Buildings',
              'Ground Truths: Roads & Tracks', 'Ground Truths: Trees',
              'Ground Truths: Crops', 'Ground Truths: Water']

    key = str(image_number).zfill(2)
    image = X_DICT_TRAIN[key]
    truths = Y_DICT_TRAIN[key]
    number_of_GTbands = truths.shape[2]

    # squeeze=False keeps a 2-D array of axes even when there is a single
    # subplot (zero mask bands); take its only row.
    f, axarr = plt.subplots(1, number_of_GTbands + 1, figsize=(25, 25), squeeze=False)
    axarr = axarr[0]

    print('Image shape is: ', image.shape)
    print("Ground Truth's shape is: ", truths.shape)

    # plot_rgb expects band-first data, hence the transpose.
    ep.plot_rgb(image.transpose([2, 0, 1]),
                rgb=[0, 1, 2],
                title=labels[0],
                stretch=True,
                ax=axarr[0])

    for i in range(number_of_GTbands):
        axis = axarr[i + 1]
        axis.imshow(truths[:, :, i])
        # Fall back to a generic title when there are more mask bands than
        # labels (the original raised IndexError).
        if i + 1 < len(labels):
            title = labels[i + 1]
        else:
            title = 'Ground Truths: band %d' % i
        axis.set_title(title)

    plt.show()
Ejemplo n.º 21
0
def test_stretch_output_scaled(rgb_image):
    """Test that stretch changes the array mean

    For n unique str_clip values, we expect n unique array means.
    """
    arr, _ = rgb_image
    stretch_vals = list(range(10))
    figures = []
    try:
        mean_vals = []
        for v in stretch_vals:
            fig, ax = plot_rgb(arr, stretch=True, str_clip=v)
            figures.append(fig)
            mean_vals.append(ax.get_images()[0].get_array().mean())
        n_unique_means = np.unique(np.array(mean_vals)).shape[0]
        assert n_unique_means == len(stretch_vals)
    finally:
        # Close every figure created above. The original try/finally only
        # deleted the list of axes, leaving all ten figures open.
        for fig in figures:
            plt.close(fig)
Ejemplo n.º 22
0
def test_rgb_extent(rgb_image):
    """Check extent, band order, title and figure size handling of plot_rgb.

    With an explicit extent the axis limits must equal that extent; with
    rgb=(1, 2, 0) the drawn array must follow that band order; and the
    title and 5x5 figsize must be applied to the resulting figure.
    """
    image, extent = rgb_image
    fig, axis = plot_rgb(
        image,
        extent=extent,
        rgb=(1, 2, 0),
        title="My Title",
        figsize=(5, 5),
    )
    drawn_extent = axis.get_xlim() + axis.get_ylim()
    drawn = axis.get_images()[0].get_array()
    expected_row = image.transpose([1, 2, 0])[1]

    assert fig.bbox_inches.bounds[2:4] == (5, 5)
    assert axis.get_title() == "My Title"
    assert np.array_equal(drawn[0], expected_row)
    assert extent == drawn_extent
    plt.close(fig)
Ejemplo n.º 23
0
def plotStuff(stack):
    """Plot three stretched Landsat 8 composites of *stack*, each in its own figure.

    Band orders: RGB (3, 2, 1), false-colour CIR (4, 2, 1) and "CUV" (3, 2, 0).
    Each figure is 12x12.
    """
    composites = (
        ((3, 2, 1), "Landsat 8 RGB Image"),
        ((4, 2, 1), "Landsat 8 CIR Image"),
        ((3, 2, 0), "Landsat 8 CUV Image"),
    )
    for band_order, heading in composites:
        fig, ax = plt.subplots(figsize=(12, 12))
        ep.plot_rgb(stack, rgb=band_order, ax=ax, title=heading, stretch=True)
Ejemplo n.º 24
0
def main():
    """Train a super-resolution model configured by a JSON option file.

    Command line: ``-opt <path>`` (required), parsed with
    ``option.parse(..., is_train=True)``.

    The function first sets up experiment folders and loggers (plus optional
    TensorBoard writers), seeds the RNGs, and builds the train/val
    dataloaders and the model. It then runs ``opt['train']['niter']``
    iterations:

    * Every ``opt['train']['val_freq']`` iterations the validation set is
      evaluated: PSNR/SSIM/ERGAS of both the SR output and the LR input
      against HR. A few example images are saved.
    * Every ``opt['logger']['save_checkpoint_freq']`` iterations the model
      and training state are saved.

    A final ``'latest'`` model is saved at the end.
    """
    # NOTE(review): PreUp appears to mean "LR inputs already have HR size":
    # when False, LR images are resized to HR size before computing metrics;
    # when True, the scale ratio is forced to 5. Confirm with the data setup.
    PreUp = False

    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which returns None for missing keys.
    ratio = opt["scale"]
    if PreUp:
        ratio = 5

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before this, logging does not work.
    util.setup_logger(None, opt['path']['log'], 'train', level=logging.INFO, screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))

    # tensorboard logger; the same condition guards every use further down
    use_tb_logger = opt['use_tb_logger'] and 'debug' not in opt['name']
    if use_tb_logger:
        from tensorboardX import SummaryWriter
        tb_logger_train = SummaryWriter(log_dir='/mnt/gpid07/users/luis.salgueiro/git/mnt/BasicSR/tb_logger/' + opt['name'] + "/train")
        tb_logger_val = SummaryWriter(log_dir='//mnt/gpid07/users/luis.salgueiro/git/mnt/BasicSR/tb_logger/' + opt['name'] + "/val" )

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = 100
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # Enable cuDNN autotuning. The original misspelled the attribute as
    # ``benckmark``, which silently set an unused attribute instead.
    torch.backends.cudnn.benchmark = True

    # ---------------------------------------------------------------- data
    # Initialise both loaders so a missing phase fails the asserts below with
    # an AssertionError rather than an UnboundLocalError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            print("Entro DATASET train......")
            train_set = create_dataset(dataset_opt)
            print("CREO DATASET train_set ", train_set)

            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
            print("CREO train loader: ", train_loader)
        elif phase == 'val':
            print("Entro en phase VAL....")
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'],
                                                                      len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None
    assert val_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        print("RESUMING state")
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0
        print("PASO.....   INIT ")

    # ------------------------------------------------------------ training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            model.update_learning_rate()

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log train
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger: keys containing "test" go to the val writer
                    if use_tb_logger:
                        if "test" in k:
                            tb_logger_val.add_scalar(k, v, current_step)
                        else:
                            tb_logger_train.add_scalar(k, v, current_step)
                logger.info(message)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                avg_psnr_sr = 0.0
                avg_psnr_lr = 0.0
                avg_psnr_dif = 0.0
                avg_ssim_lr, avg_ssim_sr, avg_ssim_dif = 0.0, 0.0, 0.0
                avg_ergas_lr, avg_ergas_sr, avg_ergas_dif = 0.0, 0.0, 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()

                    # NOTE(review): SR is de-normalised with the LR min/max, not
                    # the HR ones; confirm this is intended.
                    sr_img = util.tensor2imgNorm(visuals['SR'], out_type=np.uint8, min_max=(0, 1), MinVal=val_data["LR_min"], MaxVal=val_data["LR_max"])
                    gt_img = util.tensor2imgNorm(visuals['HR'], out_type=np.uint8, min_max=(0, 1), MinVal=val_data["HR_min"], MaxVal=val_data["HR_max"])
                    lr_img = util.tensor2imgNorm(visuals['LR'], out_type=np.uint8, min_max=(0, 1),
                                                 MinVal=val_data["LR_min"], MaxVal=val_data["LR_max"])

                    # Save the first 9 validation images for visual reference
                    if idx < 10:
                        util.mkdir(img_dir)
                        save_img_path = os.path.join(img_dir, '{:s}_{:d}'.format(img_name, current_step))
                        util.save_imgSR(sr_img, save_img_path)
                        util.save_imgHR(gt_img, save_img_path)
                        util.save_imgLR(lr_img, save_img_path)
                        print("SAVING CROPS")
                        util.save_imgCROP(lr_img, gt_img, sr_img, save_img_path, ratio, PreUp=PreUp)

                    if not PreUp:
                        # lr_img is channel-first (it is transposed to (H, W, C)
                        # for cv2 below); gt_img is assumed to share that layout.
                        # cv2.resize takes dsize as (width, height). The original
                        # passed (shape[1], shape[1]), correct only for square images.
                        dim2 = (gt_img.shape[2], gt_img.shape[1])
                        print("DIM:", dim2)
                        print("LR image shape ", lr_img.shape)
                        print("HR image shape ", gt_img.shape)
                        lr_img = cv2.resize(np.transpose(lr_img, (1, 2, 0)), dim2, interpolation=cv2.INTER_NEAREST)
                        lr_img = np.transpose(lr_img, (2, 0, 1))
                        print("LR image 2 shape ", lr_img.shape)

                    avg_psnr_sr += util.calculate_psnr2(sr_img, gt_img)
                    avg_psnr_lr += util.calculate_psnr2(lr_img, gt_img)
                    avg_ssim_lr += util.calculate_ssim2(lr_img, gt_img)
                    avg_ssim_sr += util.calculate_ssim2(sr_img, gt_img)
                    avg_ergas_lr += util.calculate_ergas(lr_img, gt_img, pixratio=ratio)
                    avg_ergas_sr += util.calculate_ergas(sr_img, gt_img, pixratio=ratio)

                avg_psnr_sr = avg_psnr_sr / idx
                avg_psnr_lr = avg_psnr_lr / idx
                avg_psnr_dif = avg_psnr_lr - avg_psnr_sr
                avg_ssim_lr = avg_ssim_lr / idx
                avg_ssim_sr = avg_ssim_sr / idx
                avg_ssim_dif = avg_ssim_lr - avg_ssim_sr
                avg_ergas_lr = avg_ergas_lr / idx
                avg_ergas_sr = avg_ergas_sr / idx
                avg_ergas_dif = avg_ergas_lr - avg_ergas_sr

                # log VALIDATION
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr_sr))
                logger.info('# Validation # SSIM: {:.4e}'.format(avg_ssim_sr))
                logger.info('# Validation # ERGAS: {:.4e}'.format(avg_ergas_sr))

                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr_SR: {:.4e}'.format(
                    epoch, current_step, avg_psnr_sr))
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr_LR: {:.4e}'.format(
                    epoch, current_step, avg_psnr_lr))
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr_DIF: {:.4e}'.format(
                    epoch, current_step, avg_psnr_dif))

                logger_val.info('<epoch:{:3d}, iter:{:8,d}> ssim_LR: {:.4e}'.format(
                    epoch, current_step, avg_ssim_lr))
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> ssim_SR: {:.4e}'.format(
                    epoch, current_step, avg_ssim_sr))
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> ssim_DIF: {:.4e}'.format(
                    epoch, current_step, avg_ssim_dif))

                logger_val.info('<epoch:{:3d}, iter:{:8,d}> ergas_LR: {:.4e}'.format(
                    epoch, current_step, avg_ergas_lr))
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> ergas_SR: {:.4e}'.format(
                    epoch, current_step, avg_ergas_sr))
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> ergas_DIF: {:.4e}'.format(
                    epoch, current_step, avg_ergas_dif))

                # tensorboard logger
                if use_tb_logger:
                    tb_logger_val.add_scalar('dif_PSNR', avg_psnr_dif, current_step)
                    tb_logger_val.add_scalar('dif_SSIM', avg_ssim_dif, current_step)
                    tb_logger_val.add_scalar('dif_ERGAS', avg_ergas_dif, current_step)

                    tb_logger_val.add_scalar('psnr_LR', avg_psnr_lr, current_step)
                    tb_logger_val.add_scalar('ssim_LR', avg_ssim_lr, current_step)
                    tb_logger_val.add_scalar('ERGAS_LR', avg_ergas_lr, current_step)

                    tb_logger_val.add_scalar('psnr_SR', avg_psnr_sr, current_step)
                    tb_logger_val.add_scalar('ssim_SR', avg_ssim_sr, current_step)
                    tb_logger_val.add_scalar('ERGAS_SR', avg_ergas_sr, current_step)

                    print("****** SR_IMG: ", sr_img.shape)
                    print("****** LR_IMG: ", lr_img.shape)
                    print("****** GT_IMG: ", gt_img.shape)

                    # Only the last validation image's figures are logged.
                    fig1, ax1 = ep.plot_rgb(sr_img, rgb=[2, 1, 0], stretch=True)
                    tb_logger_val.add_figure("SR_plt", fig1, current_step, close=True)
                    fig2, ax2 = ep.plot_rgb(gt_img, rgb=[2, 1, 0], stretch=True)
                    tb_logger_val.add_figure("GT_plt", fig2, current_step, close=True)
                    fig3, ax3 = ep.plot_rgb(lr_img, rgb=[2, 1, 0], stretch=True)
                    tb_logger_val.add_figure("LR_plt", fig3, current_step, close=True)

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
Ejemplo n.º 25
0
def main():
    """Show a hyperspectral NetCDF cube as RGB; right-click plots a pixel spectrum.

    The ``.nc`` file is located with ``findFile()`` and its ``Data`` variable
    is read into memory. Each band is then flipped (rot90 by 180 degrees
    followed by a left-right flip). Bands 29, 21 and 16 are displayed as
    R, G, B. Right-clicking the image opens a Plotly line chart, in the
    browser, of the clicked pixel's values for bands 1..depth-1.
    """
    # Read everything needed, then close the file (the original never did).
    ds = netCDF4.Dataset(findFile(), "r")
    try:
        x_dim = ds.dimensions['x']
        y_dim = ds.dimensions['y']
        depth = ds.dimensions['Bands'].size

        # Debug output. The labels are kept as in the original: "Max x" is
        # derived from the 'y' dimension and "Max y" from the 'x' dimension.
        print("Max x = " + str(y_dim.size - 1))
        print("Max y = " + str(x_dim.size - 1))
        print(depth)

        # [:] copies the variable into memory, so it stays usable after close().
        arrtemp = ds.variables["Data"][:]
    finally:
        ds.close()

    R, G, B = 29, 21, 16
    # NOTE(review): band 0 is excluded from the spectrum plot (bands 1..depth-1),
    # as in the original; confirm this is intended.
    xaxis = list(range(1, depth))

    # Apply transformation to the arrays so they have the right orientation.
    arrmain = np.zeros(arrtemp.shape)
    for c in range(depth):
        arrmain[c] = np.fliplr(np.rot90(arrtemp[c, :, :], 2))

    # Debug probe of one fixed pixel; skipped for images too small to contain
    # it, where the original raised IndexError.
    if arrmain.shape[1] > 650 and arrmain.shape[2] > 850:
        print(np.max(arrmain[:, 650, 850]))

    # Create new plot and insert data
    fig, ax = plt.subplots(figsize=(6, 6))
    ep.plot_rgb(arrmain, rgb=(R, G, B), ax=ax, title="HyperSpectral Image")
    fig.tight_layout()

    def onclick(event):
        """On a right click (button 3) inside the axes, chart that pixel's band values."""
        if event.xdata is None or event.ydata is None or event.button != 3:
            return
        # Matplotlib's xdata is the column and ydata the row; here the row is
        # named ``x`` and the column ``y``.
        y = int(event.xdata)
        x = int(event.ydata)

        print(depth)
        arr2 = [arrmain[b][x][y] for b in range(1, depth)]

        fig2 = go.Figure()
        fig2.add_trace(
            go.Scatter(x=xaxis,
                       y=arr2,
                       name="(" + str(x) + "," + str(y) + ")",
                       line=dict(color='firebrick', width=2)))
        fig2.update_layout(title="Band Information for " + "(" + str(x) +
                           "," + str(y) + ")",
                           xaxis_title='Band Number',
                           yaxis_title='Value')
        # displays graph in the web browser
        fig2.show()

    # Register the right-click listener and show the figure.
    fig.canvas.mpl_connect('button_press_event', onclick)
    plt.show()
Ejemplo n.º 26
0
###############################################################################
# Plot RGB Composite Image
# --------------------------
# You can use the ``plot_rgb()`` function from the ``earthpy.plot`` module to quickly
# plot three band composite images. For RGB composite images, you will plot the red,
# green, and blue bands, which are bands 4, 3, and 2, respectively, in the image
# stack you created. Python uses zero-based indexing, so you need to subtract 1
# from each band number. Thus, the index for the red band is 3, green is 2,
# and blue is 1. These index values are provided to the ``rgb`` argument to identify
# the bands for the composite image.

# Create a single 12x12 figure and draw the red, green and blue bands
# (stack indices 3, 2 and 1) into it.
fig, ax = plt.subplots(figsize=(12, 12))
ep.plot_rgb(arr_st, ax=ax, rgb=(3, 2, 1), title="Landsat 8 RGB Image")
plt.show()

###############################################################################
# Stretch Composite Images
# -------------------------
# Composite images can sometimes be dark if the pixel brightness values are
# skewed toward the value of zero. You can stretch the pixel brightness values
# in an image using the argument ``stretch=True`` to extend the values to the
# full 0-255 range of potential values to increase the visual contrast of the
# image. In addition, the ``str_clip`` argument allows you to specify how much of
# the tails of the data you want to clip off. The larger the number, the
# more the data will be stretched or brightened.

# Create figure with one plot and draw the same RGB composite with a linear
# stretch applied. The original created this figure but never drew into it.
fig, ax = plt.subplots(figsize=(12, 12))
ep.plot_rgb(
    arr_st,
    rgb=(3, 2, 1),
    ax=ax,
    title="Landsat 8 RGB Image with Stretch Applied",
    stretch=True,
    str_clip=0.5,
)
plt.show()
Ejemplo n.º 27
0
import gdal
import sys
# NOTE(review): IPython-only; this installs fastai on every run although
# nothing in this snippet uses it. Consider moving it to the environment's
# requirements instead.
get_ipython().system('{sys.executable} -m pip install fastai')
import glob
import os
from arcgis.gis import GIS
from arcgis.raster import ImageryLayer
from sentinelhub import SHConfig, MimeType, CRS, BBox, SentinelHubRequest, SentinelHubDownloadClient, DataSource, bbox_to_dimensions, DownloadRequest, BBoxSplitter, OsmSplitter, TileSplitter, CustomGridSplitter, UtmZoneSplitter, UtmGridSplitter
import itertools

# In[3]:

# Stack every JP2 band file under Sentinel/ in file-name order and save a
# stretched composite preview to 'edata'.
epaths = sorted(glob.glob("Sentinel/*.jp2"))
if not epaths:
    # Fail with a clear message instead of an obscure error inside es.stack.
    raise FileNotFoundError("no .jp2 files found in Sentinel/")
epath = ("Sentinel/T11SMT_20200717T182921_B01_60m.jp2")
arr_stack, metadata = es.stack(epaths)
# NOTE(review): rgb holds zero-based positions in the sorted epaths list;
# confirm that [4, 3, 1] selects the intended bands for the files present.
ep.plot_rgb(arr_stack, rgb=[4, 3, 1], stretch=True, figsize=(20, 20))
plt.savefig('edata')

# In[2]:

# Sentinel-2 processing
config = SHConfig()
# OAuth credentials must never live in source control. Read them from the
# environment; any secret that was ever committed here must be rotated.
CLIENT_ID = os.environ.get('SH_CLIENT_ID')
CLIENT_SECRET = os.environ.get('SH_CLIENT_SECRET')
if CLIENT_ID and CLIENT_SECRET:
    config.sh_client_id = CLIENT_ID
    config.sh_client_secret = CLIENT_SECRET
elif not (config.sh_client_id and config.sh_client_secret):
    # Neither the environment nor the loaded SHConfig provides credentials.
    raise RuntimeError(
        "Sentinel Hub credentials missing: set SH_CLIENT_ID and "
        "SH_CLIENT_SECRET, or store them in the sentinelhub configuration."
    )

# Area of interest as (min lon, min lat, max lon, max lat) in WGS84.
usa_bbox = -118.572693, 34.002581, -118.446350, 34.057211
resolution = 60
bbox = BBox(bbox=usa_bbox, crs=CRS.WGS84)
size = bbox_to_dimensions(bbox, resolution=resolution)
# Example No. 28
# 0
def plot_objects(obj=None,
                 img=None,
                 column=None,
                 bounds_only=True,
                 obj_extent=True,
                 obj_cmap=None,
                 linewidth=0.5,
                 alpha=1,
                 edgecolor='white',
                 rgb=None,
                 band=None,
                 plot_window=None,
                 ax=None,
                 obj_kwargs=None,
                 img_kwargs=None):
    """Plot vector objects on an image.

    Parameters:
        obj: GeoDataFrame or None
            Vector objects to draw. When an image is also given, they are
            reprojected to the image CRS. The caller's frame is not modified.
        img: str, open rasterio dataset, or None
            Image drawn underneath. A string is treated as a path, opened
            here and closed again once its pixels are read. Anything else is
            assumed to be an open rasterio DatasetReader and is left open.
        column: str or None
            Column of ``obj`` used for colouring. When None, every object
            gets the same value through a temporary column.
        bounds_only: bool
            Draw only object boundaries rather than filled geometries.
        obj_extent: bool
            When both ``obj`` and ``img`` are given, crop the image to the
            total bounds of ``obj``.
        obj_cmap, linewidth, alpha:
            Forwarded to ``GeoDataFrame.plot``.
        edgecolor:
            Edge colour for filled geometries. Unused when ``bounds_only``.
        rgb: sequence of three zero-based band indices or None
            Bands for ``ep.plot_rgb``. Defaults to ``[4, 2, 1]``.
        band: int or None
            1-based band number drawn alone with ``ep.plot_bands``. Chosen
            automatically for single-band images.
        plot_window: (minx, miny, maxx, maxy) or None
            Final axis limits.
        ax: matplotlib Axes or None
            Axes to draw on. A new 15x15 figure is created when None.
        obj_kwargs, img_kwargs: dict or None
            Extra keyword arguments for the vector and image plotting calls.

    Returns:
        (fig, ax): the figure owning ``ax`` and the axes that was drawn on.
    """
    # None sentinels instead of shared mutable default arguments.
    if rgb is None:
        rgb = [4, 2, 1]
    if obj_kwargs is None:
        obj_kwargs = {}
    if img_kwargs is None:
        img_kwargs = {}

    # Create a figure and ax if not provided. Otherwise use the figure that
    # owns the supplied ax; ``fig`` is needed for show() and the return value.
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(15, 15))
    else:
        fig = ax.figure

    # Plot the img if provided
    if img is not None:
        logger.debug('Plotting imagery...')
        # A path is opened here (and closed below). Anything else is assumed
        # to be an open rasterio DatasetReader owned by the caller.
        opened_here = isinstance(img, str)
        if opened_here:
            img = rio.open(img)
        try:
            if obj is not None:
                # Reproject first so total_bounds below is expressed in the
                # image's coordinates, matching what geo2pixel expects.
                obj = obj.to_crs(img.crs)
            img_arr = img.read(masked=True)

            if obj is not None and obj_extent:
                logger.debug('Using objects extent for plotting.')
                minx, miny, maxx, maxy = obj.total_bounds
                minrow, mincol = geo2pixel(y=maxy, x=minx, img=img)
                maxrow, maxcol = geo2pixel(y=miny, x=maxx, img=img)
                # NOTE(review): rasterio's read() returns (bands, rows, cols),
                # yet axis 1 is sliced with the column range and axis 2 with
                # the row range. This is only correct if geo2pixel returns
                # (col, row); confirm.
                img_arr = img_arr[:, mincol:maxcol, minrow:maxrow]
                logger.debug("Object ext: {} {} {} {}".format(
                    minx, miny, maxx, maxy))
                # matplotlib/earthpy extent order is (left, right, bottom, top).
                img_ext = (minx, maxx, miny, maxy)
            else:
                logger.debug('Using images full extent for plotting.')
                img_ext = plotting_extent(img)
        finally:
            if opened_here:
                img.close()

        if not band and img_arr.shape[0] == 1:
            band = 1
        if band is not None:
            ep.plot_bands(img_arr[band - 1],
                          extent=img_ext,
                          ax=ax,
                          **img_kwargs)
        else:
            ep.plot_rgb(img_arr, rgb=rgb, ax=ax, extent=img_ext, **img_kwargs)

    # Plot the objects if provided
    if obj is not None:
        logger.debug('Plotting objects...')
        # If a column is not provided, plot all as the same color
        if column is None:
            # Pick a temporary column name that doesn't already exist.
            logger.debug('Creating a temporary column for plotting')
            column = 'temp_{}'.format(np.random.randint(10000))
            while column in list(obj):
                column = 'temp_{}'.format(np.random.randint(10000))
            # assign() returns a new frame, so the caller's obj is untouched.
            obj = obj.assign(**{column: 1})
        if bounds_only:
            logger.debug('Plotting objects...')
            obj.set_geometry(obj.geometry.boundary).plot(ax=ax,
                                                         column=column,
                                                         cmap=obj_cmap,
                                                         alpha=alpha,
                                                         linewidth=linewidth,
                                                         **obj_kwargs)
        else:
            obj.plot(ax=ax,
                     column=column,
                     cmap=obj_cmap,
                     alpha=alpha,
                     linewidth=linewidth,
                     edgecolor=edgecolor,
                     **obj_kwargs)
    if plot_window:
        logger.debug('Updating to use passed window as extent.')
        ax.set_xlim([plot_window[0], plot_window[2]])
        ax.set_ylim([plot_window[1], plot_window[3]])

    fig.show()

    return fig, ax
# Example No. 29
# 0
# NOTE(review): np.random.choice samples with replacement by default, so the
# same image can be shown more than once. Pass replace=False if distinct
# images are intended.
selected = np.random.choice(num_of_train_images, num_images)

# NOTE(review): ep.plot_rgb below is called without ax=, so this figure
# appears to stay empty; confirm it is still needed.
fig = plt.figure(figsize = (25,25))
print('-------------multispectral images----------------')

# Zero-based (R, G, B) band indices for the false-colour composite, keyed by
# the number of channels in the image.
rgb_by_channels = {8: (4, 2, 1), 4: (2, 1, 0)}

for ind in selected:
    image_name = train_files[ind]
    # The context manager closes the dataset once its pixels are read.
    with rasterio.open(os.path.join(train_path, image_name)) as src:
        raster_arr = src.read()
    channels = raster_arr.shape[0]
    print("image : " , image_name, "channels :", channels)
    rgb = rgb_by_channels.get(channels)
    if rgb is None:
        # Without a mapping, rgb would be undefined or stale from the
        # previous image, so skip images with unexpected channel counts.
        print("skipping image :", image_name, "unsupported channels :", channels)
        continue
    ep.plot_rgb(raster_arr, rgb=rgb, stretch=True, title=image_name)
    plt.show()

"""Loading the dataset from the split dataset

referred from:\
https://github.com/tkshnkmr/frcnn_medium_sample,
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

Checking the proportion of data classes in the dataset
"""

dataFrame = pd.read_csv(train_labels, usecols=["Idx", "Image Index", "IsLandfill"])
data = dataFrame.values.tolist()
# Example No. 30
# 0
def plot_images(t1, t2, t3, t4, t5, target, predicted, fusion, input_imag, cnt,
                psnr_output, ssim_out, psnr_input, ssim_in):
    """Draw a 3x5 comparison grid of inputs, target, prediction and fusion, then save it.

    Layout:
        Row 0: the five inputs t1..t5 as composites of bands (2, 1, 0).
        Row 1: target (2, 1, 0); target (3, 2, 1) titled "NIR, R, G";
            predicted (2, 1, 0) titled with psnr_output/ssim_out;
            predicted (3, 2, 1); input_imag (2, 1, 0) titled with
            psnr_input/ssim_in.
        Row 2: fusion (2, 1, 0), then fusion bands 0, 1, 2 and 3, each shown
            alone (the same band repeated in all three display channels).

    Parameters:
        t1, t2, t3, t4, t5, target, predicted, fusion, input_imag:
            Image arrays passed to ``ep.plot_rgb``. Presumably band-first
            (bands, rows, cols) as earthpy expects; TODO confirm. target,
            predicted and fusion are indexed up to band 3, so they need at
            least four bands.
        cnt: Value appended to the output file name.
        psnr_output, ssim_out: Scores shown (2 decimals) in the prediction title.
        psnr_input, ssim_in: Scores shown (2 decimals) in the input_imag title.

    Side effects:
        Writes "Results" + str(cnt) + ".png" relative to the current working
        directory. NOTE(review): the figure is not closed in the code shown;
        if this is called in a loop, consider plt.close(fig).
    """

    fig = plt.figure(constrained_layout=True)

    # 3 rows x 5 columns; axes are numbered row-major (f_ax1 = row 0, col 0).
    gs = fig.add_gridspec(3, 5)

    f_ax1 = fig.add_subplot(gs[0, 0])
    f_ax2 = fig.add_subplot(gs[0, 1])
    f_ax3 = fig.add_subplot(gs[0, 2])
    f_ax4 = fig.add_subplot(gs[0, 3])
    f_ax5 = fig.add_subplot(gs[0, 4])
    f_ax6 = fig.add_subplot(gs[1, 0])
    f_ax7 = fig.add_subplot(gs[1, 1])
    f_ax8 = fig.add_subplot(gs[1, 2])
    f_ax9 = fig.add_subplot(gs[1, 3])
    f_ax10 = fig.add_subplot(gs[1, 4])
    f_ax11 = fig.add_subplot(gs[2, 0])
    f_ax12 = fig.add_subplot(gs[2, 1])
    f_ax13 = fig.add_subplot(gs[2, 2])
    f_ax14 = fig.add_subplot(gs[2, 3])
    f_ax15 = fig.add_subplot(gs[2, 4])

    # Row 0: the five input images.
    ep.plot_rgb(t1, rgb=(2, 1, 0), ax=f_ax1, title="Input 1")
    ep.plot_rgb(t2, rgb=(2, 1, 0), ax=f_ax2, title="Input 2")
    ep.plot_rgb(t3, rgb=(2, 1, 0), ax=f_ax3, title="Input 3")
    ep.plot_rgb(t4, rgb=(2, 1, 0), ax=f_ax4, title="Input 4")
    ep.plot_rgb(t5, rgb=(2, 1, 0), ax=f_ax5, title="Input 5")
    # Row 1: prediction, target, NIR false-colour views and the baseline input.
    ep.plot_rgb(predicted,
                rgb=(2, 1, 0),
                ax=f_ax8,
                title="SR({:.2f}, {:.2f})".format(psnr_output, ssim_out))
    ep.plot_rgb(target, rgb=(2, 1, 0), ax=f_ax6, title="HR(PSNR/SSIM)")
    ep.plot_rgb(target, rgb=(3, 2, 1), ax=f_ax7, title="NIR, R, G")
    ep.plot_rgb(predicted, rgb=(3, 2, 1), ax=f_ax9,
                title="NIR, R, G")  # NIR R G
    ep.plot_rgb(input_imag,
                rgb=(2, 1, 0),
                ax=f_ax10,
                title="B+M({:.2f}, {:.2f})".format(psnr_input, ssim_in))
    # Row 2: fusion output, single bands first, then the RGB composite in column 0.
    # NOTE(review): these titles call fusion band 0 "Red" and band 2 "Blue",
    # while the (2, 1, 0) RGB calls and the "NIR, R, G" (3, 2, 1) titles above
    # imply band 0 = blue and band 2 = red. Confirm fusion's band order.
    ep.plot_rgb(fusion, rgb=(0, 0, 0), ax=f_ax12, title="FNet Red")
    ep.plot_rgb(fusion, rgb=(1, 1, 1), ax=f_ax13, title="FNet Green")
    ep.plot_rgb(fusion, rgb=(2, 2, 2), ax=f_ax14, title="FNet Blue")
    ep.plot_rgb(fusion, rgb=(3, 3, 3), ax=f_ax15, title="FNet NIR")
    ep.plot_rgb(fusion, rgb=(2, 1, 0), ax=f_ax11, title="FNet RGB")

    fig.savefig("Results" + str(cnt) + ".png")