예제 #1
0
파일: classify.py 프로젝트: johanez/yatsm
def classify(ctx, config, algo, job_number, total_jobs, resume):
    """Classify the saved YATSM results for the lines assigned to this job.

    Args:
        ctx: CLI context object (not used here).
        config (str): path to the YATSM configuration file.
        algo (str): path to a classifier saved with ``joblib``. NOTE: joblib
            files are pickles, so only load classifiers from trusted sources.
        job_number (int): index of this job among ``total_jobs``.
        total_jobs (int): number of jobs the lines are distributed across.
        resume (bool): skip lines for which ``try_resume`` reports the
            output has already been processed.
    """
    cfg = parse_config_file(config)

    df = csvfile_to_dataframe(cfg['dataset']['input_file'],
                              cfg['dataset']['date_format'])
    # Only the number of rows of the first image is needed to split the work
    nrow = get_image_attribute(df['filename'][0])[0]

    classifier = joblib.load(algo)

    # Split into lines and classify
    job_lines = distribute_jobs(job_number, total_jobs, nrow)
    logger.debug('Responsible for lines: {l}'.format(l=job_lines))

    start_time = time.time()
    logger.info('Starting to run lines')
    for job_line in job_lines:
        filename = get_output_name(cfg['dataset'], job_line)
        if not os.path.exists(filename):
            # Nothing to classify for this line: skip it rather than letting
            # try_resume / classify_line fail on a missing file.
            logger.warning('No model result found for line {l} '
                           '(file {f})'.format(l=job_line, f=filename))
            continue

        if resume and try_resume(filename):
            logger.debug('Already processed line {l}'.format(l=job_line))
            continue

        logger.debug('Classifying line {l}'.format(l=job_line))
        classify_line(filename, classifier)

    logger.info('Completed {n} lines in {m} minutes'.format(
        n=len(job_lines),
        m=round((time.time() - start_time) / 60.0, 2))
    )
예제 #2
0
파일: line.py 프로젝트: johanez/yatsm
def line(ctx, config, job_number, total_jobs,
         resume, check_cache, do_not_run, verbose_yatsm):
    """Run the configured YATSM algorithm over every line assigned to a job.

    For each assigned line the image data are read (possibly through the
    line cache), each column's time series is masked, fit and
    post-processed, and all resulting records for the line are saved with
    ``np.savez`` together with the version and run metadata.

    Args:
        ctx: CLI context object (not used here).
        config (str): path to the YATSM configuration file.
        job_number (int): index of this job among ``total_jobs``.
        total_jobs (int): number of jobs the lines are distributed across.
        resume (bool): skip lines whose output file can already be loaded
            with ``np.load``.
        check_cache: accepted for interface compatibility; not read here.
        do_not_run (bool): only read each line's data (``read_line`` may
            write the cache); do not fit models or write output.
        verbose_yatsm (bool): set the algorithm logger to ``DEBUG``.

    Raises:
        click.Abort: phenology is enabled but ``yatsm.phenology`` could not
            be imported, or the output directory cannot be created or
            written to.
    """
    import errno

    if verbose_yatsm:
        logger_algo.setLevel(logging.DEBUG)

    # Parse config
    cfg = parse_config_file(config)

    # Phenology is optional: a missing section counts as disabled, and this
    # single flag is reused everywhere so all checks agree.
    do_pheno = bool('phenology' in cfg and cfg['phenology'].get('enable'))
    if do_pheno and not pheno:
        click.secho('Could not import yatsm.phenology but phenology metrics '
                    'are requested', fg='red')
        click.secho('Error: %s' % pheno_exception, fg='red')
        raise click.Abort()

    # Make sure output directory exists and is writable
    output_dir = cfg['dataset']['output']
    try:
        os.makedirs(output_dir)
    except OSError as e:
        # An existing directory is fine; anything else (permission denied,
        # a regular file in the way, ...) means results cannot be written.
        if e.errno != errno.EEXIST or not os.path.isdir(output_dir):
            click.secho('Cannot create output directory %s' % output_dir,
                        fg='red')
            raise click.Abort()

    if not os.access(output_dir, os.W_OK):
        click.secho('Cannot write to output directory %s' % output_dir,
                    fg='red')
        raise click.Abort()

    # Test existence of cache directory
    read_cache, write_cache = test_cache(cfg['dataset'])

    logger.info('Job {i} of {n} - using config file {f}'.format(i=job_number,
                                                                n=total_jobs,
                                                                f=config))
    df = csvfile_to_dataframe(cfg['dataset']['input_file'],
                              cfg['dataset']['date_format'])
    df['image_ID'] = get_image_IDs(df['filename'])

    # Get attributes of one of the images
    nrow, ncol, nband, dtype = get_image_attribute(df['filename'][0])

    # Calculate the lines this job ID works on
    job_lines = distribute_jobs(job_number, total_jobs, nrow)
    logger.debug('Responsible for lines: {l}'.format(l=job_lines))

    # Calculate X feature input
    dates = np.asarray(df['date'])
    kws = {'x': dates}
    kws.update(df.to_dict())
    X = patsy.dmatrix(cfg['YATSM']['design_matrix'], kws)
    cfg['YATSM']['design'] = X.design_info.column_name_indexes

    if cfg['YATSM']['reverse']:
        X = np.flipud(X)

    # Masking settings are identical for every pixel; mask_band may be None
    mask_band = cfg['dataset']['mask_band']
    idx_mask = None if mask_band is None else mask_band - 1
    min_values = cfg['dataset']['min_values']
    max_values = cfg['dataset']['max_values']

    # Algorithm class and its init/fit keyword arguments
    algo_name = cfg['YATSM']['algorithm']
    algo_cls = cfg['YATSM']['algorithm_cls']
    algo_cfg = cfg[algo_name]
    init_kwargs = algo_cfg.get('init', {})
    fit_kwargs = algo_cfg.get('fit', {})

    # Create output metadata to save
    md = {
        'YATSM': cfg['YATSM'],
        algo_name: algo_cfg
    }
    if do_pheno:
        md['phenology'] = cfg['phenology']

    # Begin process
    start_time_all = time.time()
    for job_line in job_lines:
        out = get_output_name(cfg['dataset'], job_line)

        if resume:
            try:
                previous = np.load(out)
            except Exception:
                # Missing or unreadable output: (re)process this line
                pass
            else:
                # np.load of an .npz archive keeps the file open until closed
                close = getattr(previous, 'close', None)
                if close is not None:
                    close()
                logger.debug('Already processed line %s' % job_line)
                continue

        logger.debug('Running line %s' % job_line)
        start_time = time.time()

        Y = read_line(job_line, df['filename'], df['image_ID'],
                      cfg['dataset'], ncol, nband, dtype,
                      read_cache=read_cache, write_cache=write_cache,
                      validate_cache=False)
        if do_not_run:
            continue
        if cfg['YATSM']['reverse']:
            # NOTE(review): X and Y are flipped in time but `dates` is not;
            # confirm the algorithm expects un-reversed dates here.
            Y = np.fliplr(Y)

        output = []
        for col in np.arange(Y.shape[-1]):
            _Y = Y.take(col, axis=2)
            valid = cyprep.get_valid_mask(_Y, min_values,
                                          max_values).astype(bool)
            if idx_mask is not None:
                # Drop observations flagged by the mask band, then drop the
                # mask band itself so only the fit bands remain
                valid &= np.in1d(_Y.take(idx_mask, axis=0),
                                 cfg['dataset']['mask_values'],
                                 invert=True)
                _Y = np.delete(_Y, idx_mask, axis=0)
            _Y = _Y[:, valid]
            _X = X[valid, :]
            _dates = dates[valid]

            # Run model
            yatsm = algo_cls(lm=cfg['YATSM']['prediction_object'],
                             **init_kwargs)
            yatsm.px = col
            yatsm.py = job_line

            try:
                yatsm.fit(_X, _Y, _dates, **fit_kwargs)
            except TSLengthException:
                # Presumably too few observations to fit this pixel; skip it
                continue

            if yatsm.record is None or len(yatsm.record) == 0:
                continue

            # Postprocess
            if cfg['YATSM'].get('commission_alpha'):
                yatsm.record = postprocess.commission_test(
                    yatsm, cfg['YATSM']['commission_alpha'])

            for prefix, lm in zip(cfg['YATSM']['refit']['prefix'],
                                  cfg['YATSM']['refit']['prediction_object']):
                yatsm.record = postprocess.refit_record(yatsm, prefix, lm,
                                                        keep_regularized=True)

            if do_pheno:
                pcfg = cfg['phenology']
                ltm = pheno.LongTermMeanPhenology(**pcfg.get('init', {}))
                yatsm.record = ltm.fit(yatsm, **pcfg.get('fit', {}))

            output.extend(yatsm.record)

        logger.debug('    Saving YATSM output to %s' % out)
        np.savez(out,
                 record=np.array(output),
                 version=__version__,
                 metadata=md)

        run_time = time.time() - start_time
        logger.debug('Line %s took %ss to run' % (job_line, run_time))

    logger.info('Completed {n} lines in {m} minutes'.format(
                n=len(job_lines),
                m=round((time.time() - start_time_all) / 60.0, 2)))
예제 #3
0
def _tc_annual_stats(year_groups, month_groups, column):
    """Return annual summary statistics for one Tasseled Cap column.

    Args:
        year_groups: observations grouped by calendar year.
        month_groups: per-(year, month) maxima, grouped again by year.
        column (str): column name, e.g. ``'tcb'``.

    Returns:
        dict of Series indexed by year: ``'mean'`` (annual mean of all
        observations), ``'max_val'`` / ``'min_val'`` (largest / smallest
        monthly maximum in the year) and ``'max_idx'`` / ``'min_idx'`` (the
        ``(year, month)`` label where those occur).
    """
    monthly = month_groups[column]
    return {
        'mean': year_groups[column].mean(),
        'max_val': monthly.max(),
        'max_idx': monthly.idxmax(),
        'min_val': monthly.min(),
        'min_idx': monthly.idxmin(),
    }


def _write_tc_raster(path, like_ds, arrays, descriptions):
    """Write 2D ``arrays`` as bands 1..N of a new Int32 GeoTIFF at ``path``.

    The raster size comes from the first array; projection and geotransform
    are copied from ``like_ds``. Every band gets nodata value 0 and the
    matching entry of ``descriptions``.
    """
    ysize, xsize = arrays[0].shape[:2]
    out_ds = gdal.GetDriverByName("GTiff").Create(
        path, xsize, ysize, len(arrays), gdal.GDT_Int32)
    out_ds.SetProjection(like_ds.GetProjection())
    out_ds.SetGeoTransform(like_ds.GetGeoTransform())
    for band_num, (array, description) in enumerate(
            zip(arrays, descriptions), 1):
        out_ds.GetRasterBand(band_num).WriteArray(array)
        out_ds.GetRasterBand(band_num).SetNoDataValue(0)
        out_ds.GetRasterBand(band_num).SetDescription(description)
    # Dropping the last reference flushes the data and closes the file
    out_ds = None


def annual(row1, row2, pct):
    """Compute annual Tasseled Cap statistics for image rows and save GeoTIFFs.

    For every pixel in rows ``row1`` (inclusive) to ``row2`` (exclusive), the
    time series is masked by physical limits, by Fmask and by a simple
    multi-temporal outlier test. Then, for TC brightness/greenness/wetness,
    it stores the annual mean, the smallest and largest monthly maximum, and
    the months in which those occur. One set of rasters is written per year
    in ``years``. Skipped pixels, unprocessed rows and years without
    observations stay 0, which is declared as nodata.

    Args:
        row1 (int): first row to process.
        row2 (int): row after the last row to process.
        pct: percentile-clipping switch; only a falsy value is supported.

    Raises:
        NotImplementedError: if ``pct`` is truthy.
    """
    # Fail fast, before any data are read: the percentile branch never
    # defined its bounds (pct1/pct2) nor stored its results.
    if pct:
        raise NotImplementedError(
            'Percentile clipping (pct) is not implemented')

    NDV = -9999

    # EXAMPLE IMAGE for dimensions, map creation
    #example_img_fn = '/projectnb/landsat/users/valpasq/LCMS/stacks/p035r032/images/example_img'
    example_img_fn = '/projectnb/landsat/projects/Massachusetts/p012r031/images/example_img'

    # YATSM CONFIG FILE
    #config_file = '/projectnb/landsat/users/valpasq/LCMS/stacks/p035r032/p035r032_config_LCMS.yaml'
    config_file = '/projectnb/landsat/projects/Massachusetts/p012r031/p012r031_config_pixel.yaml'

    #WRS2 = 'p027r027'
    WRS2 = 'p012r031'

    # Up front -- declare hard coded dataset attributes (for now)
    BAND_NAMES = [
        'blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'therm', 'tcb', 'tcg',
        'tcw', 'fmask'
    ]
    n_band = len(BAND_NAMES) - 1
    col_names = [
        'date', 'blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'therm',
        'tcb', 'tcg', 'tcw'
    ]
    dtype = np.int16
    years = range(1985, 2016, 1)
    length = len(years)  # one output slice per year in `years`
    tc_names = ('tcb', 'tcg', 'tcw')

    # Read in example image
    example_img = read_image(example_img_fn)
    py_dim = example_img.shape[0]
    px_dim = example_img.shape[1]
    print('Shape of example image:')
    print(example_img.shape)

    # Read in and parse config file (plain YAML; file closed afterwards)
    with open(config_file) as config_fh:
        cfg = yaml.safe_load(config_fh)
    # List to np.ndarray so it works with cyprep.get_valid_mask
    cfg['dataset']['min_values'] = np.asarray(cfg['dataset']['min_values'])
    cfg['dataset']['max_values'] = np.asarray(cfg['dataset']['max_values'])
    idx_mask = cfg['dataset']['mask_band'] - 1

    # Get files list
    df = csvfile_to_dataframe(cfg['dataset']['input_file'],
                              date_format=cfg['dataset']['date_format'])

    # Get dates for image stack
    df['image_ID'] = get_image_IDs(df['filename'])
    df['x'] = df['date']
    dates = df['date'].values

    # results[statistic][tc column] -> (row, col, year) array
    results = dict(
        (stat, dict((tc, np.zeros((py_dim, px_dim, length)))
                    for tc in tc_names))
        for stat in ('mean', 'min_val', 'min_idx', 'max_val', 'max_idx'))

    for py in range(row1, row2):  # row iterator
        print('Working on row {py}'.format(py=py))
        sys.stdout.flush()
        start_time = time.time()

        Y_row = read_line(
            py,
            df['filename'],
            df['image_ID'],
            cfg['dataset'],
            px_dim,
            n_band + 1,  # +1 for now for Fmask
            dtype,
            read_cache=False,
            write_cache=False,
            validate_cache=False)

        for px in range(px_dim):  # column iterator
            Y = Y_row.take(px, axis=2)

            # Skip pixels with too much missing data in the first six bands
            if (Y[0:6] == NDV).mean() > 0.3:
                continue

            # Mask based on physical constraints and Fmask
            valid = cyprep.get_valid_mask(
                Y,
                cfg['dataset']['min_values'],
                cfg['dataset']['max_values']).astype(bool)
            valid &= np.in1d(Y.take(idx_mask, axis=0),
                             cfg['dataset']['mask_values'],
                             invert=True)
            if not valid.any():
                continue

            # Mask time series using fmask result (and drop the Fmask band)
            Y_fmask = np.delete(Y, idx_mask, axis=0)[:, valid]
            dates_fmask = dates[valid]

            # Apply multi-temporal mask (modified tmask)
            # Step 1. drop observations where green is more than 3 standard
            # deviations above its mean
            green = Y_fmask[1]
            keep = np.where(green < (np.mean(green) + np.std(green) * 3))[0]
            dates_fmask = dates_fmask[keep]
            Y_fmask = Y_fmask[:, keep]
            # Step 2. drop observations where SWIR1 is more than 3 standard
            # deviations below its mean (after step 1)
            swir1 = Y_fmask[4]
            keep = np.where(swir1 > (np.mean(swir1) - np.std(swir1) * 3))[0]
            dates_fmask = dates_fmask[keep]
            Y_fmask = Y_fmask[:, keep]

            # Nothing survived masking (e.g. a single clear observation):
            # the date accessors below would fail on an empty frame
            if dates_fmask.size == 0:
                continue

            # convert time from ordinal to dates
            dt_dates_fmask = np.array(
                [dt.datetime.fromordinal(d) for d in dates_fmask])

            # Create dataframe for analysis: one row per observation
            data_fmask = np.concatenate(
                [dt_dates_fmask.reshape(dt_dates_fmask.shape[0], 1),
                 np.transpose(Y_fmask)],
                axis=1)
            data_fmask_df = pd.DataFrame(data_fmask, columns=col_names)
            # convert reflectance to numeric type
            data_fmask_df[BAND_NAMES[0:10]] = data_fmask_df[
                BAND_NAMES[0:10]].astype(int)

            # Group observations by year to generate annual TS
            year_group_fmask = data_fmask_df.groupby(
                data_fmask_df.date.dt.year)

            # Maximum per (year, month), then grouped again by year.
            # TODO(review): the original author flagged this grouping as
            # needing a fix without stating the problem.
            month_group_fmask = data_fmask_df.groupby(
                [data_fmask_df.date.dt.year,
                 data_fmask_df.date.dt.month]).max()
            month_groups = month_group_fmask.groupby(
                month_group_fmask.date.dt.year)

            stats = dict(
                (tc, _tc_annual_stats(year_group_fmask, month_groups, tc))
                for tc in tc_names)

            present_years = stats['tcb']['mean'].index
            for index, year in enumerate(years):
                if year not in present_years:
                    continue
                for tc in tc_names:
                    tc_stats = stats[tc]
                    results['mean'][tc][py, px, index] = tc_stats['mean'][year]
                    results['min_val'][tc][py, px, index] = \
                        tc_stats['min_val'][year]
                    results['max_val'][tc][py, px, index] = \
                        tc_stats['max_val'][year]
                    # idxmin/idxmax give the (year, month) label; keep month
                    results['min_idx'][tc][py, px, index] = \
                        tc_stats['min_idx'][year][1]
                    results['max_idx'][tc][py, px, index] = \
                        tc_stats['max_idx'][year][1]

        run_time = time.time() - start_time
        print('Line {line} took {run_time}s to run'.format(line=py,
                                                           run_time=run_time))
        sys.stdout.flush()

    print('Statistics complete')
    print('Writing results to raster...')
    start_time = time.time()

    # Output maps for each year
    in_ds = gdal.Open(example_img_fn, gdal.GA_ReadOnly)
    output_root = '/projectnb/landsat/users/valpasq/LCMS/dataviz/results'
    tc_labels = {'tcb': 'TC Brightness',
                 'tcg': 'TC Greenness',
                 'tcw': 'TC Wetness'}
    # (statistic, output sub-directory, filename tag, description prefix)
    products = (
        ('mean', 'mean', 'mean', 'Mean Annual'),
        ('min_val', 'min', 'min_val', 'Minimum Annual'),
        ('max_val', 'max', 'max_val', 'Maximum Annual'),
        ('min_idx', 'min', 'min_mon', 'Month of Minimum Annual'),
        ('max_idx', 'max', 'max_mon', 'Month of Maximum Annual'),
    )
    for index, year in enumerate(years):
        for stat, subdir, tag, label in products:
            out_fn = ('{root}/{WRS2}/{subdir}/{WRS2}_ST-BGW_{tag}_{year}_'
                      '{row1}-{row2}.tif').format(
                          root=output_root, WRS2=WRS2, subdir=subdir,
                          tag=tag, year=year, row1=row1, row2=row2)
            _write_tc_raster(
                out_fn, in_ds,
                [results[stat][tc][:, :, index] for tc in tc_names],
                ['{0} {1}'.format(label, tc_labels[tc]) for tc in tc_names])
    in_ds = None  # close the template dataset

    run_time = time.time() - start_time
    print('Rasters took {run_time}s to export'.format(run_time=run_time))
    sys.stdout.flush()
예제 #4
0
def pixel(ctx, config, px, py, band, plot, ylim, style, cmap,
          embed, seed, algo_kw):
    """Plot one pixel's time series, fit YATSM to it, and plot the results.

    Args:
        ctx: CLI context object (not used here).
        config (str): path to the YATSM configuration file.
        px (int): pixel column.
        py (int): pixel row.
        band (int): 1-based band number to plot.
        plot: plot types to draw (``'TS'``, ``'DOY'`` or ``'VAL'``).
        ylim: optional y-axis limits passed to ``plt.ylim``.
        style (str): matplotlib style name, or ``'xkcd'``.
        cmap (str): palette name looked up in ``palettable.colorbrewer``,
            ``palettable.cubehelix`` and ``palettable.wesanderson``.
        embed (bool): open an IPython shell before each plot is shown, when
            one is available.
        seed: seed for ``np.random``.
        algo_kw (dict): ``{parameter: YAML string}`` overrides, applied to
            every config section that already contains ``parameter``.

    Raises:
        click.Abort: if ``cmap`` cannot be found in ``palettable``.
    """
    # Set seed
    np.random.seed(seed)
    # Convert band to index
    band -= 1

    # Get colormap from the first palettable module that defines it
    for palette_module in (palettable.colorbrewer,
                           palettable.cubehelix,
                           palettable.wesanderson):
        if hasattr(palette_module, cmap):
            mpl_cmap = getattr(palette_module, cmap).mpl_colormap
            break
    else:
        # click does not display a message passed to Abort, so print it
        click.secho('Cannot find specified colormap in `palettable`',
                    fg='red')
        raise click.Abort()

    # Parse config
    cfg = parse_config_file(config)

    # Apply algorithm overrides
    revalidate = False
    for kw in algo_kw:
        for cfg_key in cfg:
            section = cfg[cfg_key]
            # Only dict sections hold parameters; `in` on a string section
            # would be a substring test followed by a failing assignment
            if isinstance(section, dict) and kw in section:
                # Parse as YAML for type conversions used in config parser;
                # safe_load because the value comes from the command line
                value = yaml.safe_load(algo_kw[kw])

                print('Overriding cfg[%s][%s]=%s with %s' %
                      (cfg_key, kw, section[kw], value))
                section[kw] = value
                revalidate = True

    if revalidate:
        cfg = convert_config(cfg)

    # Locate and fetch attributes from data
    df = csvfile_to_dataframe(cfg['dataset']['input_file'],
                              date_format=cfg['dataset']['date_format'])
    df['image_ID'] = get_image_IDs(df['filename'])

    # Setup X/Y
    kws = {'x': df['date']}
    kws.update(df.to_dict())
    X = patsy.dmatrix(cfg['YATSM']['design_matrix'], kws)

    Y = read_pixel_timeseries(df['filename'], px, py)

    # Mask out of range data, then mask-band flags (mask_band may be None)
    valid = cyprep.get_valid_mask(Y,
                                  cfg['dataset']['min_values'],
                                  cfg['dataset']['max_values']).astype(bool)
    mask_band = cfg['dataset']['mask_band']
    if mask_band is not None:
        idx_mask = mask_band - 1
        valid &= np.in1d(Y[idx_mask, :], cfg['dataset']['mask_values'],
                         invert=True)
        Y = np.delete(Y, idx_mask, axis=0)

    # Apply mask
    Y = Y[:, valid]
    X = X[valid, :]
    dates = np.array([dt.datetime.fromordinal(d) for d in df['date'][valid]])

    def _show_plots(model=None):
        """Draw each requested plot of `band`; overlay `model` if given."""
        with plt.xkcd() if style == 'xkcd' else mpl.style.context(style):
            for _plot in plot:
                if _plot == 'TS':
                    plot_TS(dates, Y[band, :])
                elif _plot == 'DOY':
                    plot_DOY(dates, Y[band, :], mpl_cmap)
                elif _plot == 'VAL':
                    plot_VAL(dates, Y[band, :], mpl_cmap)

                if ylim:
                    plt.ylim(ylim)
                plt.title('Timeseries: px={px} py={py}'.format(px=px, py=py))
                plt.ylabel('Band {b}'.format(b=band + 1))

                if model is not None:
                    plot_results(band, cfg['YATSM'], model, plot_type=_plot)

                if embed and has_embed:
                    IPython_embed()

                plt.tight_layout()
                plt.show()

    # Plot before fitting
    _show_plots()

    # Fit the configured algorithm with its (possibly overridden) parameters
    yatsm = cfg['YATSM']['algorithm_cls'](lm=cfg['YATSM']['prediction_object'],
                                          **cfg[cfg['YATSM']['algorithm']])
    yatsm.px = px
    yatsm.py = py
    yatsm.fit(X, Y, np.asarray(df['date'][valid]))

    # Plot after predictions
    _show_plots(yatsm)
예제 #5
0
파일: cache.py 프로젝트: johanez/yatsm
def cache(ctx, config, job_number, total_jobs, update_pattern, interlace):
    """Cache the image data for the lines assigned to this job.

    Each assigned line gets a cache file in the dataset's cache directory.
    When ``update_pattern`` matches existing cache files and one of them
    matches the line, it is passed to ``update_cache_file``. Otherwise the
    line is read with the configured image reader and written with
    ``write_cache_file``.

    Args:
        ctx: CLI context object (not used here).
        config (str): path to the YATSM configuration file.
        job_number (int): index of this job among ``total_jobs``.
        total_jobs (int): number of jobs the lines are distributed across.
        update_pattern (str): ``fnmatch`` pattern selecting previous cache
            files to update from; falsy to always read from the images.
        interlace (bool): passed to ``distribute_jobs`` as ``interlaced``.
    """
    cfg = parse_config_file(config)
    cache_dir = cfg['dataset']['cache_line_dir']

    if not os.path.isdir(cache_dir):
        os.makedirs(cache_dir)

    df = csvfile_to_dataframe(cfg['dataset']['input_file'],
                              cfg['dataset']['date_format'])
    df['image_IDs'] = get_image_IDs(df['filename'])

    nrow, ncol, nband, dtype = reader.get_image_attribute(df['filename'][0])

    # Determine lines to work on
    job_lines = distribute_jobs(job_number, total_jobs, nrow,
                                interlaced=interlace)
    logger.debug('Responsible for lines: {l}'.format(l=job_lines))

    # Determine file reader once; used for both fresh reads and updates
    if cfg['dataset']['use_bip_reader']:
        logger.debug('Reading in data from disk using BIP reader')
        image_reader = reader.read_row_BIP
        image_reader_kwargs = {'size': (ncol, nband),
                               'dtype': dtype}
    else:
        logger.debug('Reading in data from disk using GDAL')
        image_reader = reader.read_row_GDAL
        image_reader_kwargs = {}

    # Attempt to update cache files
    previous_cache = None
    if update_pattern:
        previous_cache = fnmatch.filter(os.listdir(cache_dir), update_pattern)

        if not previous_cache:
            logger.warning('Could not find cache files to update with pattern '
                           '%s' % update_pattern)
        else:
            logger.debug('Found %s previously cached files to update' %
                         len(previous_cache))

    for job_line in job_lines:
        cache_filename = get_line_cache_name(cfg['dataset'], len(df),
                                             job_line, nband)
        logger.debug('Caching line {l} to {f}'.format(
            l=job_line, f=cache_filename))
        start_time = time.time()

        # Find matching cache file, if any
        update = None
        if previous_cache:
            pattern = get_line_cache_pattern(job_line, nband, regex=False)
            potential = fnmatch.filter(previous_cache, pattern)

            if not potential:
                logger.info('Found zero previous cache files for '
                            'line {l}'.format(l=job_line))
            else:
                if len(potential) > 1:
                    logger.info('Found more than one previous cache file for '
                                'line {l}. Keeping first'.format(l=job_line))
                update = os.path.join(cache_dir, potential[0])
                logger.info('Updating from cache file {f}'.format(f=update))

        if update:
            update_cache_file(df['filename'], df['image_IDs'],
                              update, cache_filename,
                              job_line, image_reader, image_reader_kwargs)
        else:
            Y = image_reader(df['filename'], job_line, **image_reader_kwargs)
            write_cache_file(cache_filename, Y, df['image_IDs'])

        logger.debug('Took {s}s to cache the data'.format(
            s=round(time.time() - start_time, 2)))