Example #1
def test_register(capsys):
    try:
        p = ProgressBar()
        p.register()

        assert Callback.active

        get_threaded(dsk, 'e')
        check_bar_completed(capsys)

        p.unregister()

        assert not Callback.active
    finally:
        Callback.active.clear()
Example #2
def test_register(capsys):
    try:
        p = ProgressBar()
        p.register()

        assert _globals['callbacks']

        get(dsk, 'e')
        check_bar_completed(capsys)

        p.unregister()

        assert not _globals['callbacks']
    finally:
        _globals['callbacks'].clear()
Example #3
def test_register(capsys):
    try:
        p = ProgressBar()
        p.register()

        assert _globals['callbacks']

        get(dsk, 'e')
        out, err = capsys.readouterr()
        bar, percent, time = [i.strip() for i in out.split('\r')[-1].split('|')]
        assert bar == "[########################################]"
        assert percent == "100% Completed"

        p.unregister()

        assert not _globals['callbacks']
    finally:
        _globals['callbacks'].clear()
Example #4
def flow_allocation(n,
                    snapshots=None,
                    method='Average participation',
                    to_netcdf=None,
                    round_floats=8,
                    **kwargs):
    """
    Allocate or decompose the network flow with different methods.

    Available methods are 'Average participation' ('ap'), 'Marginal
    participation' ('mp'), 'Virtual injection pattern' ('vip'),
    'Zbus transmission' ('zbus').



    Parameters
    ----------
    n : pypsa.Network
        Network object with valid flow data.
    snapshots : string or pandas.DatetimeIndex
        (Subset of) snapshots of the network. If None (default) all snapshots
        are taken.
    per_bus : Boolean, default is False
        Whether to allocate the flow in a peer-to-peer manner.
    method : string
        Type of the allocation method. Should be one of

            - 'Average participation'/'ap':
                Trace the active power flow from source to sink
                (or sink to source) using the principle of proportional
                sharing and calculate the partial flows on each line,
                or to each bus where the power goes to (or comes from).
            - 'Marginal participation'/'mp':
                Allocate line flows according to linear sensitivities
                of nodal power injection given by the changes in the
                power transfer distribution factors (PTDF)
            - 'Equivalent bilateral exchanges'/'ebe':
                Sequentially calculate the load flow induced by
                individual power sources in the network ignoring other
                sources and scaling down sinks.
            - 'Zbus transmission'/'zbus'

    Returns
    -------
    res : xr.Dataset
        Dataset with allocations depending on the method.
    """
    snapshots = check_snapshots(snapshots, n)
    n.calculate_dependent_values()
    if all(c.pnl.p0.empty for c in n.iterate_components(n.branch_components)):
        raise ValueError('Flows are not given by the network, '
                         'please solve the network flows first')

    if method not in _func_dict:
        raise ValueError('Method not implemented, please choose one out of '
                         f'{list(_func_dict)}')

    is_nonsequential_func = _func_dict[method] in _non_sequential_funcs
    if isinstance(snapshots, (str, pd.Timestamp)) or is_nonsequential_func:
        return _func_dict[method](n, snapshots, **kwargs)

    logger.info('Calculate allocations')

    func = _func_dict[method]
    res = [dask.delayed(func)(n, sn, **kwargs) for sn in snapshots]
    with ProgressBar():
        res = xr.concat(dask.compute(*res), dim=snapshots.rename('snapshot'))

    return res
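A minimal usage sketch for the function above, assuming a pypsa network `n` whose power flow has already been solved; the file name, snapshot subset and method string below are illustrative placeholders, not values taken from the source:

import pypsa

# Hypothetical input: any solved network with flow results would do.
n = pypsa.Network("solved_network.nc")

# A DatetimeIndex subset takes the dask-delayed path above, so the
# ProgressBar context manager reports per-snapshot progress.
allocation = flow_allocation(n, snapshots=n.snapshots[:24],
                             method='Average participation')
print(allocation)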
Example #5
def _down_sample(ltable,
                 rtable,
                 y_param,
                 show_progress=True,
                 verbose=False,
                 seed=None,
                 rem_puncs=True,
                 rem_stop_words=True,
                 n_ltable_chunks=-1,
                 n_rtable_chunks=-1):
    """
    Down sampling command implementation. The down sample command is reproduced
    here because the input to this variant is the already down-sampled right table.
    """

    if not isinstance(ltable, pd.DataFrame):
        logger.error('Input table A (ltable) is not of type pandas DataFrame')
        raise AssertionError(
            'Input table A (ltable) is not of type pandas DataFrame')

    if not isinstance(rtable, pd.DataFrame):
        logger.error('Input table B (rtable) is not of type pandas DataFrame')

        raise AssertionError(
            'Input table B (rtable) is not of type pandas DataFrame')

    if len(ltable) == 0 or len(rtable) == 0:
        logger.error('Size of the input table is 0')
        raise AssertionError('Size of the input table is 0')

    if y_param == 0:
        logger.error('y cannot be zero (3rd and 4th parameter of downsample)')
        raise AssertionError(
            'y_param cannot be zero (3rd and 4th parameter of downsample)')

    if seed is not None and not isinstance(seed, int):
        logger.error('Seed is not of type integer')
        raise AssertionError('Seed is not of type integer')

    validate_object_type(verbose, bool, 'Parameter verbose')
    validate_object_type(show_progress, bool, 'Parameter show_progress')
    validate_object_type(rem_stop_words, bool, 'Parameter rem_stop_words')
    validate_object_type(rem_puncs, bool, 'Parameter rem_puncs')
    validate_object_type(n_ltable_chunks, int, 'Parameter n_ltable_chunks')
    validate_object_type(n_rtable_chunks, int, 'Parameter n_rtable_chunks')

    # rtable_sampled = sample_right_table(rtable, size)
    rtable_sampled = rtable

    ltbl_str_cols = _get_str_cols_list(ltable)
    proj_ltable = ltable[ltable.columns[ltbl_str_cols]]

    if n_ltable_chunks == -1:
        n_ltable_chunks = multiprocessing.cpu_count()

    ltable_chunks = np.array_split(proj_ltable, n_ltable_chunks)
    preprocessed_tokenized_tbl = []
    start_row_id = 0
    for i in range(len(ltable_chunks)):
        result = delayed(process_tokenize_concat_strings)(ltable_chunks[i],
                                                          start_row_id,
                                                          rem_puncs,
                                                          rem_stop_words)
        preprocessed_tokenized_tbl.append(result)
        start_row_id += len(ltable_chunks[i])
    preprocessed_tokenized_tbl = delayed(wrap)(preprocessed_tokenized_tbl)
    if show_progress:
        with ProgressBar():
            logger.info('Preprocessing/tokenizing ltable')
            preprocessed_tokenized_tbl_vals = preprocessed_tokenized_tbl.compute(
                scheduler="processes", num_workers=multiprocessing.cpu_count())
    else:
        preprocessed_tokenized_tbl_vals = preprocessed_tokenized_tbl.compute(
            scheduler="processes", num_workers=multiprocessing.cpu_count())

    ltable_processed_dict = {}
    for i in range(len(preprocessed_tokenized_tbl_vals)):
        ltable_processed_dict.update(preprocessed_tokenized_tbl_vals[i])

    inverted_index = build_inverted_index(ltable_processed_dict)

    rtbl_str_cols = _get_str_cols_list(rtable_sampled)
    proj_rtable_sampled = rtable_sampled[rtable_sampled.columns[rtbl_str_cols]]

    if n_rtable_chunks == -1:
        n_rtable_chunks = multiprocessing.cpu_count()

    rtable_chunks = np.array_split(proj_rtable_sampled, n_rtable_chunks)
    probe_result = []

    for i in range(len(rtable_chunks)):
        result = delayed(probe)(rtable_chunks[i], y_param, len(proj_ltable),
                                inverted_index, rem_puncs, rem_stop_words,
                                seed)
        probe_result.append(result)

    probe_result = delayed(wrap)(probe_result)
    if show_progress:
        with ProgressBar():
            logger.info('Probing using rtable')
            probe_result = probe_result.compute(
                scheduler="processes", num_workers=multiprocessing.cpu_count())
    else:
        probe_result = probe_result.compute(
            scheduler="processes", num_workers=multiprocessing.cpu_count())

    probe_result = map(list, probe_result)
    l_tbl_indices = set(sum(probe_result, []))

    l_tbl_indices = list(l_tbl_indices)
    ltable_sampled = ltable.iloc[l_tbl_indices]

    # update catalog
    if cm.is_dfinfo_present(ltable):
        cm.copy_properties(ltable, ltable_sampled)

    if cm.is_dfinfo_present(rtable):
        cm.copy_properties(rtable, rtable_sampled)

    return ltable_sampled, rtable_sampled
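A hedged sketch of calling the helper above, assuming `A` and `B_sampled` are pandas DataFrames prepared by the caller and that `B_sampled` is the already down-sampled right table; the file names and the `y_param` value are illustrative:

import pandas as pd

# Placeholder inputs; real callers pass catalog-registered DataFrames.
A = pd.read_csv('tableA.csv')
B_sampled = pd.read_csv('tableB_downsampled.csv')

# With show_progress=True (the default) the dask ProgressBar is shown
# for both the tokenization and the probing phase.
A_sampled, B_out = _down_sample(A, B_sampled, y_param=2)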
Example #6
import output, logging
import load_subreddit_castra, make_subreddit_castra
from dask.diagnostics import ProgressBar
from pprint import pprint
import pandas as pd
import dask.dataframe as dd

logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.getLogger('requests').setLevel(logging.CRITICAL)
logger = logging.getLogger(__name__)

# Start a progress bar for all computations
pbar = ProgressBar()
pbar.register()

def test(file_name):
    """
    # Subsetting the dataframe
    a = df[df.link_id == 't3_36k7u4'].compute()

    # Get multiple columns from the dataframe
    b = df[['author', 'subreddit']].compute()

    # Groupby operations
    c = df.groupby(['link_id', 'author'])['ups'].count().compute()
    c = df.groupby(df.link_id).ups.mean().compute()
    c = df.groupby(df.link_id).score.count().compute()

    # Drop duplicates
    d = df.author.drop_duplicates().compute()
    """
Example #7
    def generate_scripts(self):
        self.log_file.write("Reading from: \t" + self.cf.smx_path)
        self.log_file.write("Output folder: \t" + self.cf.output_path)
        self.log_file.write("SMX files:")
        print("Reading from: \t" + self.cf.smx_path)
        print("Output folder: \t" + self.cf.output_path)
        print("SMX files:")
        filtered_sources = []
        self.start_time = dt.datetime.now()
        try:
            smx_files = funcs.get_smx_files(self.cf.smx_path, self.smx_ext,
                                            self.sheets)
            for smx in smx_files:
                try:
                    self.count_smx = self.count_smx + 1
                    smx_file_path = self.cf.smx_path + "/" + smx
                    smx_file_name = os.path.splitext(smx)[0]
                    print("\t" + smx_file_name)
                    self.log_file.write("\t" + smx_file_name)
                    home_output_path = self.cf.output_path + "/" + smx_file_name + "/"

                    # self.parallel_remove_output_home_path.append(delayed(md.remove_folder)(home_output_path))
                    self.parallel_create_output_home_path.append(
                        delayed(md.create_folder)(home_output_path))

                    self.parallel_templates.append(
                        delayed(gcfr.gcfr)(self.cf, home_output_path))
                    ##################################### end of read_smx_folder ################################
                    if self.cf.source_names:
                        System_sht_filter = [[
                            'Source system name', self.cf.source_names
                        ]]
                    else:
                        System_sht_filter = None

                    System = funcs.read_excel(smx_file_path,
                                              sheet_name=self.System_sht)
                    teradata_sources = System[System['Source type'] ==
                                              'TERADATA']
                    teradata_sources = funcs.df_filter(teradata_sources,
                                                       System_sht_filter,
                                                       False)
                    self.count_sources = self.count_sources + len(
                        teradata_sources.index)

                    Supplements = delayed(funcs.read_excel)(
                        smx_file_path, sheet_name=self.Supplements_sht)
                    Column_mapping = delayed(funcs.read_excel)(
                        smx_file_path, sheet_name=self.Column_mapping_sht)
                    BMAP_values = delayed(funcs.read_excel)(
                        smx_file_path, sheet_name=self.BMAP_values_sht)
                    BMAP = delayed(funcs.read_excel)(smx_file_path,
                                                     sheet_name=self.BMAP_sht)
                    BKEY = delayed(funcs.read_excel)(smx_file_path,
                                                     sheet_name=self.BKEY_sht)
                    Core_tables = delayed(funcs.read_excel)(
                        smx_file_path, sheet_name=self.Core_tables_sht)
                    Core_tables = delayed(funcs.rename_sheet_reserved_word)(
                        Core_tables, Supplements, 'TERADATA',
                        ['Column name', 'Table name'])
                    ##################################### end of read_smx_sheet ################################

                    for system_index, system_row in teradata_sources.iterrows():
                        try:
                            Loading_Type = system_row['Loading type'].upper()
                            if Loading_Type != "":
                                source_name = system_row['Source system name']
                                filtered_sources.append(source_name)

                                source_name_filter = [[
                                    'Source', [source_name]
                                ]]
                                core_layer_filter = [['Layer', ["CORE"]]]
                                stg_layer_filter = [['Layer', ["STG"]]]
                                stg_source_name_filter = [[
                                    'Source system name', [source_name]
                                ]]

                                Table_mapping = delayed(funcs.read_excel)(
                                    smx_file_path, self.Table_mapping_sht,
                                    source_name_filter)

                                core_Table_mapping = delayed(funcs.df_filter)(
                                    Table_mapping, core_layer_filter, False)
                                stg_Table_mapping = delayed(funcs.df_filter)(
                                    Table_mapping, stg_layer_filter, False)

                                STG_tables = delayed(funcs.read_excel)(
                                    smx_file_path, self.STG_tables_sht,
                                    stg_source_name_filter)
                                STG_tables = delayed(
                                    funcs.rename_sheet_reserved_word)(
                                        STG_tables, Supplements, 'TERADATA',
                                        ['Column name', 'Table name'])

                                source_output_path = home_output_path + "/" + Loading_Type + "/" + source_name

                                self.parallel_create_output_source_path.append(
                                    delayed(
                                        md.create_folder)(source_output_path))

                                self.parallel_templates.append(
                                    delayed(D000.d000)(self.cf,
                                                       source_output_path,
                                                       source_name,
                                                       core_Table_mapping,
                                                       STG_tables, BKEY))
                                self.parallel_templates.append(
                                    delayed(D001.d001)(self.cf,
                                                       source_output_path,
                                                       source_name,
                                                       STG_tables))
                                self.parallel_templates.append(
                                    delayed(D002.d002)(self.cf,
                                                       source_output_path,
                                                       Core_tables,
                                                       core_Table_mapping))
                                self.parallel_templates.append(
                                    delayed(D003.d003)(self.cf,
                                                       source_output_path,
                                                       BMAP_values, BMAP))

                                self.parallel_templates.append(
                                    delayed(D110.d110)(self.cf,
                                                       source_output_path,
                                                       stg_Table_mapping,
                                                       STG_tables,
                                                       Loading_Type))

                                self.parallel_templates.append(
                                    delayed(D200.d200)(self.cf,
                                                       source_output_path,
                                                       STG_tables,
                                                       Loading_Type))
                                self.parallel_templates.append(
                                    delayed(D210.d210)(self.cf,
                                                       source_output_path,
                                                       STG_tables,
                                                       Loading_Type))

                                self.parallel_templates.append(
                                    delayed(D300.d300)(self.cf,
                                                       source_output_path,
                                                       STG_tables, BKEY))
                                self.parallel_templates.append(
                                    delayed(D320.d320)(self.cf,
                                                       source_output_path,
                                                       STG_tables, BKEY))
                                self.parallel_templates.append(
                                    delayed(D330.d330)(self.cf,
                                                       source_output_path,
                                                       STG_tables, BKEY))
                                self.parallel_templates.append(
                                    delayed(D340.d340)(self.cf,
                                                       source_output_path,
                                                       STG_tables, BKEY))

                                # self.parallel_templates.append(delayed(D400.d400)(self.cf, source_output_path, STG_tables))
                                # self.parallel_templates.append(delayed(D410.d410)(self.cf, source_output_path, STG_tables))
                                # self.parallel_templates.append(delayed(D415.d415)(self.cf, source_output_path, STG_tables))
                                self.parallel_templates.append(
                                    delayed(D420.d420)(self.cf,
                                                       source_output_path,
                                                       STG_tables, BKEY, BMAP,
                                                       Loading_Type))

                                self.parallel_templates.append(
                                    delayed(D600.d600)(self.cf,
                                                       source_output_path,
                                                       core_Table_mapping,
                                                       Core_tables))
                                self.parallel_templates.append(
                                    delayed(D607.d607)(self.cf,
                                                       source_output_path,
                                                       Core_tables,
                                                       BMAP_values))
                                self.parallel_templates.append(
                                    delayed(D608.d608)(self.cf,
                                                       source_output_path,
                                                       Core_tables,
                                                       BMAP_values))
                                self.parallel_templates.append(
                                    delayed(D610.d610)(self.cf,
                                                       source_output_path,
                                                       core_Table_mapping))
                                self.parallel_templates.append(
                                    delayed(D615.d615)(self.cf,
                                                       source_output_path,
                                                       Core_tables))
                                self.parallel_templates.append(
                                    delayed(D620.d620)(self.cf,
                                                       source_output_path,
                                                       core_Table_mapping,
                                                       Column_mapping,
                                                       Core_tables,
                                                       Loading_Type))
                                self.parallel_templates.append(
                                    delayed(D630.d630)(self.cf,
                                                       source_output_path,
                                                       core_Table_mapping))
                                self.parallel_templates.append(
                                    delayed(D640.d640)(self.cf,
                                                       source_output_path,
                                                       source_name,
                                                       core_Table_mapping))

                                self.parallel_templates.append(
                                    delayed(
                                        testing_script_01.source_testing_script
                                    )(self.cf, source_output_path, source_name,
                                      core_Table_mapping, Column_mapping,
                                      STG_tables, BKEY))
                                self.parallel_templates.append(
                                    delayed(
                                        testing_script_02.source_testing_script
                                    )(self.cf, source_output_path, source_name,
                                      Table_mapping, Core_tables))

                        except Exception as e_source:
                            # print(error)

                            # log: smx_file_name, source_name
                            print(system_row.to_dict())
                            funcs.SMXFilesLogError(
                                self.cf.output_path, smx,
                                str(system_row.to_dict()),
                                traceback.format_exc()).log_error()
                            self.count_sources = self.count_sources - 1

                except Exception as e_smx_file:
                    # print(error)
                    funcs.SMXFilesLogError(self.cf.output_path, smx, None,
                                           traceback.format_exc()).log_error()
                    self.count_smx = self.count_smx - 1

        except Exception as e1:
            # print(error)
            # traceback.print_exc()
            funcs.SMXFilesLogError(self.cf.output_path, None, None,
                                   traceback.format_exc()).log_error()

        if len(self.parallel_templates) > 0:
            sources = funcs.list_to_string(filtered_sources, ', ')
            print("Sources:", sources)
            self.log_file.write("Sources:" + sources)
            scheduler_value = 'processes' if self.cf.read_sheets_parallel == 1 else ''
            with config.set(scheduler=scheduler_value):
                # compute(*self.parallel_remove_output_home_path)
                compute(*self.parallel_create_output_home_path)
                compute(*self.parallel_create_output_source_path)

                with ProgressBar():
                    smx_files = " smx files" if self.count_smx > 1 else " smx file"
                    smx_file_sources = " sources" if self.count_sources > 1 else " source"
                    print("Start generating " +
                          str(len(self.parallel_templates)) + " script for " +
                          str(self.count_sources) + smx_file_sources +
                          " from " + str(self.count_smx) + smx_files)
                    compute(*self.parallel_templates)
                    self.log_file.write(
                        str(len(self.parallel_templates)) +
                        " script generated for " + str(self.count_sources) +
                        smx_file_sources + " from " + str(self.count_smx) +
                        smx_files)
                    self.elapsed_time = dt.datetime.now() - self.start_time
                    self.log_file.write("Elapsed Time: " +
                                        str(self.elapsed_time))
            self.error_message = ""
            os.startfile(self.cf.output_path)
        else:
            self.error_message = "No SMX Files Found!"

        self.log_file.close()
Example #8
import numpy as np
import pandas as pd
import logging
import re
from scipy.stats import pearsonr
from subprocess import check_output

from matplotlib import pyplot as plt
import seaborn as sns

sns.set(font_scale=1.5, style="whitegrid")

import dask.bag as db
from dask.diagnostics import ProgressBar

ProgressBar().register()

logger = logging.getLogger(__name__)
logging.basicConfig(format='[%(asctime)s - %(name)s] %(message)s',
                    datefmt='%H:%M:%S',
                    level=logging.DEBUG,
                    handlers=[logging.StreamHandler()])


def import_gdrive_sheet(gdrive_key, sheet_id):
    run_spreadsheet = pd.read_csv(
        "https://docs.google.com/spreadsheet/ccc?key=" + gdrive_key +
Example #9
    def rescale_intensity(
        self,
        relative: bool = False,
        in_range: Union[None, Tuple[int, int], Tuple[float, float]] = None,
        out_range: Union[None, Tuple[int, int], Tuple[float, float]] = None,
        dtype_out: Union[None, np.dtype, Tuple[int, int], Tuple[float, float]] = None,
        percentiles: Union[None, Tuple[int, int], Tuple[float, float]] = None,
    ):
        """Rescale image intensities inplace.

        Output min./max. intensity is determined from `out_range` or the
        data type range of the :class:`numpy.dtype` passed to
        `dtype_out` if `out_range` is None.

        This method is based on
        :func:`skimage.exposure.rescale_intensity`.

        Parameters
        ----------
        relative
            Whether to keep relative intensities between images (default
            is False). If True, `in_range` must be None, because
            `in_range` is in this case set to the global min./max.
            intensity.
        in_range
            Min./max. intensity of input images. If None (default),
            `in_range` is set to pattern min./max intensity. Contrast
            stretching is performed when `in_range` is set to a narrower
            intensity range than the input patterns. Must be None if
            `relative` is True or `percentiles` are passed.
        out_range
            Min./max. intensity of output images. If None (default),
            `out_range` is set to `dtype_out` min./max according to
            `skimage.util.dtype.dtype_range`.
        dtype_out
            Data type of rescaled images, default is input images' data
            type.
        percentiles
            Disregard intensities outside these percentiles. Calculated
            per image. Must be None if `in_range` or `relative` is
            passed. Default is None.

        See Also
        --------
        kikuchipy.pattern.rescale_intensity,
        :func:`skimage.exposure.rescale_intensity`

        Examples
        --------
        >>> import numpy as np
        >>> import kikuchipy as kp
        >>> s = kp.data.nickel_ebsd_small()

        Image intensities are stretched to fill the available grey
        levels in the input images' data type range or any
        :class:`numpy.dtype` range passed to `dtype_out`, either
        keeping relative intensities between images or not:

        >>> print(
        ...     s.data.dtype, s.data.min(), s.data.max(),
        ...     s.inav[0, 0].data.min(), s.inav[0, 0].data.max()
        ... )
        uint8 23 246 26 245
        >>> s2 = s.deepcopy()
        >>> s.rescale_intensity(dtype_out=np.uint16)  # doctest: +SKIP
        >>> print(
        ...     s.data.dtype, s.data.min(), s.data.max(),
        ...     s.inav[0, 0].data.min(), s.inav[0, 0].data.max()
        ... )  # doctest: +SKIP
        uint16 0 65535 0 65535
        >>> s2.rescale_intensity(relative=True)  # doctest: +SKIP
        >>> print(
        ...     s2.data.dtype, s2.data.min(), s2.data.max(),
        ...     s2.inav[0, 0].data.min(), s2.inav[0, 0].data.max()
        ... )  # doctest: +SKIP
        uint8 0 255 3 253

        Contrast stretching can be performed by passing percentiles:

        >>> s.rescale_intensity(percentiles=(1, 99))  # doctest: +SKIP

        Here, the darkest and brightest pixels within the 1% percentile
        are set to the ends of the data type range, e.g. 0 and 255
        respectively for images of ``uint8`` data type.

        Notes
        -----
        Rescaling RGB images is not possible. Use RGB channel
        normalization when creating the image instead.
        """
        if self.data.dtype in rgb_dtypes.values():
            raise NotImplementedError(
                "Use RGB channel normalization when creating the image instead."
            )

        # Determine min./max. intensity of input image to rescale to
        if in_range is not None and percentiles is not None:
            raise ValueError("'percentiles' must be None if 'in_range' is not None.")
        elif relative is True and in_range is not None:
            raise ValueError("'in_range' must be None if 'relative' is True.")
        elif relative:  # Scale relative to min./max. intensity in images
            in_range = (self.data.min(), self.data.max())

        if dtype_out is None:
            dtype_out = self.data.dtype.type

        if out_range is None:
            dtype_out_pass = dtype_out
            if isinstance(dtype_out, np.dtype):
                dtype_out_pass = dtype_out.type
            out_range = dtype_range[dtype_out_pass]

        # Create dask array of signal images and do processing on this
        dask_array = get_dask_array(signal=self)

        # Rescale images
        rescaled_images = dask_array.map_blocks(
            func=chunk.rescale_intensity,
            in_range=in_range,
            out_range=out_range,
            dtype_out=dtype_out,
            percentiles=percentiles,
            dtype=dtype_out,
        )

        # Overwrite signal images
        if not self._lazy:
            with ProgressBar():
                if self.data.dtype != rescaled_images.dtype:
                    self.change_dtype(dtype_out)
                print("Rescaling the image intensities:", file=sys.stdout)
                rescaled_images.store(self.data, compute=True)
        else:
            self.data = rescaled_images
Example #10
def main():
    parser = argparse.ArgumentParser(
        description='Add multiscale levels to an existing n5')

    parser.add_argument('-i', '--input', dest='input_path', type=str, required=True, \
        help='Path to the directory containing the n5 volume')

    parser.add_argument('-d', '--data_set', dest='data_set', type=str, default="", \
        help='Path to data set (default empty, so /s0 is assumed to exist at the root)')

    parser.add_argument('-f', '--downsampling_factors', dest='downsampling_factors', type=str, default="2,2,2", \
        help='Downsampling factors for each dimension (default "2,2,2")')

    parser.add_argument('-p', '--pixel_res', dest='pixel_res', type=str, \
        help='Pixel resolution for each dimension "2.0,2.0,2.0" (default None) - required for Neuroglancer')

    parser.add_argument('-u', '--pixel_res_units', dest='pixel_res_units', type=str, default="nm", \
        help='Measurement unit for --pixel_res (default "nm") - required for Neuroglancer')

    parser.add_argument('--distributed', dest='distributed', action='store_true', \
        help='Run with the distributed scheduler (default is a local scheduler with a progress bar)')
    parser.set_defaults(distributed=False)

    parser.add_argument('--workers', dest='workers', type=int, default=20, \
        help='If --distributed is set, this specifies the number of workers (default 20)')

    parser.add_argument('--dashboard', dest='dashboard', action='store_true', \
        help='If --distributed is set, this runs a web-based dashboard on port 8787')
    parser.set_defaults(dashboard=False)

    parser.add_argument('--metadata-only', dest='metadata_only', action='store_true', \
        help='Only fix metadata on an existing multiscale pyramid')
    parser.set_defaults(metadata_only=False)

    args = parser.parse_args()

    if args.distributed:
        dashboard_address = None
        if args.dashboard:
            dashboard_address = ":8787"
            print(f"Starting dashboard on {dashboard_address}")

        from dask.distributed import Client
        client = Client(processes=True, n_workers=args.workers, \
            threads_per_worker=1, dashboard_address=dashboard_address)

    else:
        from dask.diagnostics import ProgressBar
        pbar = ProgressBar()
        pbar.register()

    downsampling_factors = [
        int(c) for c in args.downsampling_factors.split(',')
    ]

    pixel_res = None
    if args.pixel_res:
        pixel_res = [float(c) for c in args.pixel_res.split(',')]

    if not args.metadata_only:
        add_multiscale(args.input_path,
                       args.data_set,
                       downsampling_factors=downsampling_factors)

    add_metadata(args.input_path,
                 downsampling_factors=downsampling_factors,
                 pixel_res=pixel_res,
                 pixel_res_units=args.pixel_res_units)
Example #11
class MetSim(object):
    """
    MetSim handles the distribution of jobs that write to a common file
    by launching muliple processes and queueing up their writeback so that
    work can be done while IO is happening.
    """

    # Class variables
    methods = {'mtclim': mtclim, 'passthrough': passthrough}
    params = {
        "period_ending": False,
        "is_worker": False,
        "method": 'mtclim',
        "domain": '',
        "state": '',
        "out_dir": '',
        "out_prefix": 'forcing',
        "start": 'forcing',
        "stop": 'forcing',
        "forcing_fmt": 'netcdf',
        "time_step": -1,
        "calendar": 'standard',
        "prec_type": 'uniform',
        "out_precision": 'f4',
        "verbose": 0,
        "sw_prec_thresh": 0.0,
        "utc_offset": False,
        "lw_cloud": 'cloud_deardorff',
        "lw_type": 'prata',
        "prec_type": 'uniform',
        "tdew_tol": 1e-6,
        "tmax_daylength_fraction": 0.67,
        "rain_scalar": 0.75,
        "tday_coef": 0.45,
        "lapse_rate": 0.0065,
        "out_vars": {n: available_outputs[n]
                     for n in default_outputs},
        "out_freq": None,
        "chunks": NO_SLICE,
        "scheduler": 'distributed',
        "num_workers": 1,
    }

    def __init__(self, params: dict, domain_slice=NO_SLICE):
        """
        Constructor
        """
        self._domain = None
        self._met_data = None
        self._state = None
        self._client = None
        self._domain_slice = domain_slice
        self.progress_bar = ProgressBar()
        self.params.update(params)
        logging.captureWarnings(True)
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(self.params['verbose'])

        formatter = logging.Formatter(' - '.join(
            ['%(asctime)s', '%(name)s', '%(levelname)s', '%(message)s']))
        ch = logging.StreamHandler(sys.stdout)
        ch.setFormatter(formatter)
        ch.setLevel(self.params['verbose'])
        # set global dask scheduler
        if domain_slice is NO_SLICE:
            if self.params['scheduler'] in DASK_CORE_SCHEDULERS:
                dask.config.set(scheduler=self.params['scheduler'])
            else:
                from distributed import Client, progress
                if 'distributed' == self.params['scheduler']:
                    self._client = Client(n_workers=self.params['num_workers'],
                                          threads_per_worker=1)
                    if self.params['verbose'] == logging.DEBUG:
                        self.progress_bar = progress
                elif os.path.isfile(self.params['scheduler']):
                    self._client = Client(
                        scheduler_file=self.params['scheduler'])
                else:
                    self._client = Client(self.params['scheduler'])
        else:
            dask.config.set(scheduler=self.params['scheduler'])

        # Set up logging
        # If in verbose mode set up the progress bar
        if self.params['verbose'] == logging.DEBUG:
            if 'distributed' != self.params['scheduler']:
                self.progress_bar.register()
                self.progress_bar = lambda x: x
        else:
            # If not in verbose mode, create a dummy function
            self.progress_bar = lambda x: x
        # Create time vector(s)
        self._times = self._get_output_times(
            freq=self.params['out_freq'],
            period_ending=self.params['period_ending'])

        self._update_unit_attrs(self.params['out_vars'])

    def _update_unit_attrs(self, out_vars):
        for k, v in out_vars.items():
            if 'units' in v.keys():
                if v['units'] in converters[k].keys():
                    attrs[k]['units'] = v['units']
                else:
                    self.logger.warning(
                        f'Could not find unit conversion for {k} to {v["units"]}!'
                        f' We will use the default units of'
                        f' {available_outputs[k]["units"]} instead.')
                    v['units'] = available_outputs[k]['units']
            else:
                v['units'] = available_outputs[k]['units']

    def _validate_force_times(self, force_times):
        for p, i in [('start', 0), ('stop', -1)]:
            # infer times from force_times
            if isinstance(self.params[p], str):
                if self.params[p] == 'forcing':
                    self.params[p] = pd.Timestamp(
                        force_times.values[i]).to_pydatetime()
                elif '/' in self.params[p]:
                    year, month, day = map(int, self.params[p].split('/'))
                    self.params[p] = pd.Timestamp(
                        year=year, month=month, day=day).to_pydatetime()
                else:
                    self.params[p] = pd.to_datetime(self.params[p])

        # update calendar from input data (fall back to params version)
        self.params['calendar'] = self.met_data['time'].encoding.get(
            'calendar', self.params['calendar'])

        assert self.params['start'] >= pd.Timestamp(
            force_times.values[0]).to_pydatetime()
        assert self.params['stop'] <= pd.Timestamp(
            force_times.values[-1]).to_pydatetime()

        self.params['state_start'] = (self.params['start'] -
                                      pd.Timedelta("90 days"))
        self.params['state_stop'] = (self.params['start'] -
                                     pd.Timedelta("1 days"))
        if self.params['utc_offset']:
            attrs['time'] = {
                'units': DEFAULT_TIME_UNITS,
                'long_name': 'UTC time',
                'standard_name': 'utc_time'
            }
        else:
            attrs['time'] = {
                'units': DEFAULT_TIME_UNITS,
                'long_name': 'local time at grid location',
                'standard_name': 'local_time'
            }

    def convert_monthly_param(self, name):
        self.met_data[name] = self.met_data['prec'].copy()
        months = self.met_data['time'].dt.month
        for m in range(12):
            param = self.domain[name].sel(month=m)
            locations = {'time': self.met_data['time'].isel(time=months == m)}
            self.met_data[name].loc[locations] = param

    @property
    def domain(self):
        if self._domain is None:
            self._domain = io.read_domain(
                self.params).isel(**self._domain_slice)
        return self._domain

    @property
    def met_data(self):
        if self._met_data is None:
            self._met_data = io.read_met_data(self.params, self.domain)
            self._met_data['elev'] = self.domain['elev']
            self._met_data['lat'] = self.domain['lat']
            self._met_data['lon'] = self.domain['lon']

            # process constant_vars
            constant_vars = self.params.get('constant_vars', None)
            if constant_vars:
                da_template = self._met_data[list(self._met_data)[0]]
                for var in constant_vars.keys():
                    self._met_data[var] = xr.full_like(
                        da_template, float(constant_vars[var]))

            self._validate_force_times(force_times=self._met_data['time'])
        return self._met_data

    @property
    def state(self):
        if self._state is None:
            self._state = io.read_state(self.params, self.domain)
            self._aggregate_state()
        return self._state

    @property
    def slices(self):
        if not self.params['chunks']:
            return [{d: slice(None) for d in self.domain[['mask']].dims}]

        return chunk_domain(self.params['chunks'], self.domain[['mask']].dims)

    def open_output(self):
        filenames = [self._get_output_filename(times) for times in self._times]
        return xr.open_mfdataset(filenames)

    def run(self):
        self._validate_setup()
        write_locks = {}
        for times in self._times:
            filename = self._get_output_filename(times)
            self.setup_netcdf_output(filename, times)
            write_locks[filename] = combine_locks(
                [NETCDFC_LOCK, get_write_lock(filename)])
        self.logger.info('Starting {} chunks...'.format(len(self.slices)))

        delayed_objs = [
            wrap_run_slice(self.params, write_locks, dslice)
            for dslice in self.slices
        ]
        persisted = dask.persist(delayed_objs,
                                 num_workers=self.params['num_workers'])
        self.progress_bar(persisted)
        dask.compute(persisted)
        self.logger.info('Cleaning up...')
        try:
            self._client.cluster.close()
            self._client.close()
            if self.params['verbose'] == logging.DEBUG:
                print()
                print('closed dask cluster/client')
        except Exception:
            pass

    def load_inputs(self, close=True):
        self._domain = self.domain.load()
        self._met_data = self.met_data.load()
        self._state = self.state.load()
        if close:
            self._domain.close()
            self._met_data.close()
            self._state.close()

    def setup_netcdf_output(self, filename, times):
        '''setup a single netcdf file'''
        with Dataset(filename, mode="w") as ncout:
            # dims
            dim_sizes = (None, ) + self.domain['mask'].shape
            var_dims = ('time', ) + self.domain['mask'].dims
            chunksizes = [len(times)]
            for d, s in zip(var_dims[1:], dim_sizes[1:]):
                c = int(self.params['chunks'].get(d, s))
                if c <= s:
                    chunksizes.append(c)
                else:
                    chunksizes.append(s)
            create_kwargs = {'chunksizes': chunksizes}
            for d, size in zip(var_dims, dim_sizes):
                ncout.createDimension(d, size)
            # vars
            for varname, varconf in self.params['out_vars'].items():
                ncout.createVariable(varconf['out_name'],
                                     self.params['out_precision'], var_dims,
                                     **create_kwargs)

            # add metadata and coordinate variables (time/lat/lon)
            time_var = ncout.createVariable('time', 'i4', ('time', ))
            time_var.calendar = self.params['calendar']
            time_var[:] = date2num(times.to_pydatetime(),
                                   units=attrs['time'].get(
                                       'units', DEFAULT_TIME_UNITS),
                                   calendar=time_var.calendar)

            dtype_map = {
                'float64': 'f8',
                'float32': 'f4',
                'int64': 'i8',
                'int32': 'i4'
            }
            for dim in self.domain['mask'].dims:
                dim_vals = self.domain[dim].values
                dim_dtype = dtype_map.get(str(dim_vals.dtype),
                                          self.params['out_precision'])
                dim_var = ncout.createVariable(dim, dim_dtype, (dim, ))
                dim_var[:] = dim_vals

            # parameters to not record in the metadata
            skip_params = [
                'elev',
                'lat',
                'lon',
                'is_worker',
                'out_vars',
                'forcing_vars',
                'domain_vars',
                'state_vars',
                'constant_vars',
                'references',
                'verbose',
                'num_workers',
            ]
            for k, v in self.params.items():
                if k in skip_params:
                    continue
                # Need to convert some parameters to strings
                if k in ['start', 'stop', 'utc_offset', 'period_ending']:
                    v = str(v)
                elif k in ['state_start', 'state_stop', 'out_freq']:
                    # skip
                    continue
                # Don't include complex types
                if isinstance(v, dict):
                    v = json.dumps(v)
                elif not isinstance(v, str) and isinstance(v, Iterable):
                    v = ', '.join(v)

                if isinstance(v, str):
                    v = v.replace("'", "").replace('"', "")
                attrs['_global'][k] = v

            # set global attrs
            for key, val in attrs['_global'].items():
                setattr(ncout, key, val)

            # set variable attrs
            for key, value in attrs.get('time', {}).items():
                setattr(ncout.variables['time'], key, value)
            for varname, varconf in self.params['out_vars'].items():
                outname = varconf['out_name']
                for key, val in attrs.get(varname, {}).items():
                    setattr(ncout.variables[outname], key, val)

    def write_chunk(self, locks=None):
        '''write data from a single chunk'''
        if not len(self.params['out_vars']):
            return
        for times in self._times:
            filename = self._get_output_filename(times)
            lock = locks.get(filename, DummyLock())
            time_slice = slice(times[0], times[-1])
            with lock:
                with Dataset(filename, mode="r+") as ncout:
                    for varname, varconf in self.params['out_vars'].items():
                        outname = varconf['out_name']
                        dims = ncout.variables[outname].dimensions[1:]
                        write_slice = ((slice(None), ) +
                                       tuple(self._domain_slice[d]
                                             for d in dims))
                        ncout.variables[outname][write_slice] = (
                            self.output[varname].sel(time=time_slice).values)

    def run_slice(self):
        """
        Run a single slice of
        """
        self._validate_setup()
        self.disagg = int(self.params['time_step']) < cnst.MIN_PER_DAY
        self.method = MetSim.methods[self.params['method']]
        self.setup_output()
        times = self.met_data['time']
        params = self.params.copy()
        # transform input parameters to floating point values
        params['sw_prec_thresh'] = float(params['sw_prec_thresh'])
        params['rain_scalar'] = float(params['rain_scalar'])
        params['tdew_tol'] = float(params['tdew_tol'])
        params['tmax_daylength_fraction'] = float(
            params['tmax_daylength_fraction'])
        params['tday_coef'] = float(params['tday_coef'])
        params['tmax_daylength_fraction'] = float(
            params['tmax_daylength_fraction'])
        params['lapse_rate'] = float(params['lapse_rate'])
        if self.params['prec_type'].upper() in ['TRIANGLE', 'MIX']:
            self.convert_monthly_param('dur')
            self.convert_monthly_param('t_pk')
        for index, mask_val in np.ndenumerate(self.domain['mask'].values):
            if mask_val > 0:
                locs = {d: i for d, i in zip(self.domain['mask'].dims, index)}
            else:
                continue
            df, state = wrap_run_cell(self.method.run, params,
                                      self.met_data.isel(**locs),
                                      self.state.isel(**locs), self.disagg,
                                      times)

            # Cut the returned data down to the correct time index
            # and do any required unit conversions
            for varname in self.params['out_vars']:
                desired_units = self.params['out_vars'][varname]['units']
                out_vals = converters[varname][desired_units](
                    df[varname].values, int(self.params['time_step']))
                self.output[varname][locs] = out_vals

    def _unpack_state(self, result: pd.DataFrame, locs: dict):
        """Put restart values in the state dataset"""
        # We concatenate with the old state values in case we don't
        # have 90 new days to use
        tmin = np.concatenate((self.state['t_min'].isel(**locs).values[:],
                               result['t_min'].values))
        tmax = np.concatenate((self.state['t_max'].isel(**locs).values[:],
                               result['t_max'].values))
        prec = np.concatenate(
            (self.state['prec'].isel(**locs).values[:], result['prec'].values))
        self.state['t_min'].isel(**locs).values[:] = tmin[-90:]
        self.state['t_max'].isel(**locs).values[:] = tmax[-90:]
        self.state['prec'].isel(**locs).values[:] = prec[-90:]
        state_start = result.index[-1] - pd.Timedelta('89 days')
        self.state['time'].values = date_range(
            state_start, result.index[-1], calendar=self.params['calendar'])

    def _get_output_times(self, freq=None, period_ending=False):
        """
        Generate chunked time vectors

        Parameters
        ----------
        freq:
            Output frequency. Given as a Pandas timegrouper string.
            If not given, the entire timeseries will be used.
        period_ending:
            Flag to specify if output timesteps should be period-
            ending. Default is period-beginning

        Returns
        -------
        times:
            A list of timeseries which represent each of times that
            output files will be created for.
        """
        prototype = self.met_data
        self.disagg = int(self.params['time_step']) < cnst.MIN_PER_DAY

        if self.disagg:
            delta = pd.Timedelta('1 days') - pd.Timedelta('{} minutes'.format(
                self.params['time_step']))
        else:
            delta = pd.Timedelta('0 days')
        if period_ending:
            offset = pd.Timedelta('{} minutes'.format(
                self.params['time_step']))
        else:
            offset = pd.Timedelta('0 minutes')

        start = pd.Timestamp(prototype['time'].values[0]).to_pydatetime()
        stop = pd.Timestamp(prototype['time'].values[-1]).to_pydatetime()
        times = date_range(start + offset,
                           stop + offset + delta,
                           freq="{}T".format(self.params['time_step']),
                           calendar=self.params['calendar'])

        if freq is None or freq == '':
            times = [times]
        else:
            dummy = pd.Series(np.arange(len(times)), index=times)
            grouper = pd.Grouper(freq=freq)
            times = [t.index for k, t in dummy.groupby(grouper)]
        return times

    def _get_output_filename(self, times):
        suffix = self.get_nc_output_suffix(times)
        fname = '{}_{}.nc'.format(self.params['out_prefix'], suffix)
        output_filename = os.path.join(os.path.abspath(self.params['out_dir']),
                                       fname)
        return output_filename

    def setup_output(self):

        # output times
        times = self._get_output_times(
            freq=None, period_ending=self.params['period_ending'])[0]

        # Number of timesteps
        n_ts = len(times)

        shape = (n_ts, ) + self.domain['mask'].shape
        dims = ('time', ) + self.domain['mask'].dims
        coords = {'time': times, **self.domain['mask'].coords}
        self.output = xr.Dataset(coords=coords)
        self.output['time'].encoding['calendar'] = self.params['calendar']

        dtype = self.params['out_precision']
        for varname in self.params['out_vars']:
            self.output[varname] = xr.DataArray(data=np.full(shape,
                                                             np.nan,
                                                             dtype=dtype),
                                                coords=coords,
                                                dims=dims,
                                                name=varname,
                                                attrs=attrs.get(varname, {}))
        self.output['time'].attrs.update(attrs['time'])

    def _aggregate_state(self):
        """Aggregate data out of the state file and load it into `met_data`"""
        # Precipitation record

        assert self.state.dims['time'] == 90, self.state['time']
        record_dates = date_range(self.params['state_start'],
                                  self.params['state_stop'],
                                  calendar=self.params['calendar'])
        trailing = self.state['prec']
        trailing['time'] = record_dates
        total_precip = xr.concat([trailing, self.met_data['prec']],
                                 dim='time').load()
        total_precip = (
            cnst.DAYS_PER_YEAR * total_precip.rolling(time=90).mean().sel(
                time=slice(self.params['start'], self.params['stop'])))

        self.met_data['seasonal_prec'] = total_precip

        # Smoothed daily temperature range
        trailing = self.state['t_max'] - self.state['t_min']

        trailing['time'] = record_dates
        dtr = self.met_data['t_max'] - self.met_data['t_min']
        if (dtr < 0).any():
            raise ValueError("Daily maximum temperature lower"
                             " than daily minimum temperature!")
        sm_dtr = xr.concat([trailing, dtr], dim='time').load()
        sm_dtr = sm_dtr.rolling(time=30).mean().drop(record_dates, dim='time')
        self.met_data['dtr'] = dtr
        self.met_data['smoothed_dtr'] = sm_dtr

    def _validate_setup(self):
        """Updates the global parameters dictionary"""
        errs = [""]

        # Make sure there's some input
        if not len(self.params.get('forcing', [])):
            errs.append("Requires input forcings to be specified")

        # Make sure there is at least one forcing_var
        # They cannot all be constant since we use one as a template
        # for the others
        if not len(self.params.get('forcing_vars', [])):
            errs.append("Requires at least one non-constant forcing")

        # Parameters that can't be empty strings or None
        non_empty = ['out_dir', 'time_step', 'forcing_fmt']
        for each in non_empty:
            if self.params.get(each, None) is None or self.params[each] == '':
                errs.append("Cannot have empty value for {}".format(each))

        # Make sure time step divides evenly into a day
        if (cnst.MIN_PER_DAY % int(self.params.get('time_step', -1))
                or (int(self.params['time_step']) > (6 * cnst.MIN_PER_HOUR)
                    and int(self.params['time_step']) != cnst.MIN_PER_DAY)):
            errs.append("Time step must be evenly divisible into 1440 "
                        "minutes (24 hours) and less than 360 minutes "
                        "(6 hours). Got {}.".format(self.params['time_step']))

        # Check for required input variable specification
        if self.met_data is not None:
            required_in = ['t_min', 't_max', 'prec']
            for each in required_in:
                if each not in self.met_data.variables:
                    errs.append("Input requires {}".format(each))

        # Make sure that we are going to write out some data
        if not len(self.params.get('out_vars', [])):
            errs.append("Output variable list must not be empty")

        # Check output variables are valid
        daily_out_vars = [
            't_min', 't_max', 't_day', 'prec', 'vapor_pressure', 'shortwave',
            'tskc', 'pet', 'wind', 'daylength'
        ]
        out_var_check = [
            'temp', 'prec', 'shortwave', 'vapor_pressure', 'air_pressure',
            'rel_humid', 'spec_humid', 'longwave', 'tskc', 'wind'
        ]
        if int(self.params.get('time_step', -1)) == 1440:
            out_var_check = daily_out_vars
        for var in self.params.get('out_vars', []):
            if var not in out_var_check:
                errs.append('Cannot output variable {} at timestep {}'.format(
                    var, self.params['time_step']))

        # Check that the parameters specified are available
        opts = {
            'out_precision': ['f4', 'f8'],
            'lw_cloud': ['default', 'cloud_deardorff'],
            'lw_type': [
                'default', 'tva', 'anderson', 'brutsaert', 'satterlund',
                'idso', 'prata'
            ]
        }
        for k, v in opts.items():
            if not self.params.get(k, None) in v:
                errs.append("Invalid option given for {}".format(k))

        # If any errors, raise and give a summary
        if len(errs) > 1:
            raise Exception("\n  ".join(errs))

    def get_nc_output_suffix(self, times):
        s, e = times[[0, -1]]
        template = '{:04d}{:02d}{:02d}-{:04d}{:02d}{:02d}'
        return template.format(
            s.year,
            s.month,
            s.day,
            e.year,
            e.month,
            e.day,
        )
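The time-step check in `_validate_setup` above encodes a simple rule: the step must divide 1440 minutes evenly and either be no longer than 6 hours or be exactly one day. A minimal standalone sketch of that rule, assuming the `cnst` constants have their usual values:

MIN_PER_DAY, MIN_PER_HOUR = 1440, 60   # assumed values of cnst.MIN_PER_DAY / cnst.MIN_PER_HOUR

def time_step_is_valid(step_minutes):
    # valid if it divides a day and is either sub-daily (<= 6 h) or exactly daily
    divides_day = MIN_PER_DAY % step_minutes == 0
    allowed_length = step_minutes <= 6 * MIN_PER_HOUR or step_minutes == MIN_PER_DAY
    return divides_day and allowed_length

assert time_step_is_valid(60) and time_step_is_valid(1440)
assert not time_step_is_valid(480)   # divides the day but is longer than 6 hours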
示例#12
0
def calcule_stat_climat(l_fichiers_nc,
                        nom_var,
                        l_type_periode,
                        fichier_sortie,
                        verbose=1,
                        date_debut=None,
                        date_fin=None):
    """fonction qui calcule les stat climatiques a partir d'une liste
    de fichiers netCDF

    """
    # check that the arguments are consistent
    assert (len(l_fichiers_nc) > 0)

    # open the input files
    ds = nc_to_xr_dataset(l_fichiers_nc, patch_xr_open_mfdataset=True)
    dates = ds.time
    # find the frequency of the input data
    freq_entree = trouve_freq(dates)

    print()
    print(("frequence des donnees en entree: {:s}".format(freq_entree)))
    assert (freq_entree != 'NA')

    # limited to monthly data for now
    assert (freq_entree == 'mois')

    # TODO check that there are no gaps in the monthly data series

    # select the desired variable over the requested date range
    da = ds[nom_var].sel(time=slice(date_debut, date_fin))

    # make sure there is data to process
    assert (da.time.size > 0)

    # display information
    print(('\ntraitement de la tranche \n{:} \n@ \n{:}'.format(
        da.time.values[0], da.time.values[-1])))
    # superfluous? dates = pd.to_datetime(da.time.values)

    # loop over the requested period types
    for n_f, type_periode in enumerate(l_type_periode):

        # start of processing
        print()
        print(
            f'calcul des stat climatiques pour les periodes de type: {type_periode}'
        )

        # if the data are monthly, generate the seasonal or annual
        # mean series, weighting each month by its number of days
        # (according to the calendar)
        calendrier = nc_cherche_calendrier(l_fichiers_nc)
        if freq_entree == 'mois':
            if type_periode == 'saison':
                ds_serie = xr_se_mens_a_sais(da, calendrier)
            elif type_periode == 'annee':
                ds_serie = xr_se_mens_a_an(da, calendrier)
            elif type_periode == 'mois':
                ds_serie = da
            else:
                msg = f'type de periode non prevue: {type_periode}'
                raise RuntimeError(msg)
        else:
            msg = "on ne traite que les fichiers de statistiques mensuelles pour l'instant"
            raise RuntimeError(msg)

        # compute the climate statistics
        #
        # split the data according to the requested type_periode
        if type_periode == 'annee':
            l_ds = [ds_serie]
        elif type_periode == 'saison':
            l_ds = []
            for saison in 'DJF MAM JJA SON'.split():
                ds_saison = ds_serie.sel(
                    time=ds_serie.time.dt.season == saison)
                l_ds.append(ds_saison)
        elif type_periode == 'mois':
            l_ds = []
            for mois in range(1, 13):
                ds_mois = ds_serie.sel(time=ds_serie.time.dt.month == mois)
                l_ds.append(ds_mois)

        # compute the climate statistics for this period
        ds_stat_climat_periode = xr_calcule_stat_climat(l_ds)

        # generate a time.periode coordinate identifying each period
        # (i.e. season name, month number or 'annee')
        if type_periode == 'annee':
            etik = ['AN']
        elif type_periode == 'mois':
            etik = [
                str(m) for m in ds_stat_climat_periode.time.dt.month.values
            ]
        elif type_periode == 'saison':
            # l_saison = 'DJF DJF MAM MAM MAM JJA JJA JJA SON SON SON DJF'.split()
            # etik = [l_saison[m - 1] for m in ds_stat_climat_periode.time.dt.month]
            etik = ds_stat_climat_periode.time.dt.season.values
        ds_stat_climat_periode.coords['periode'] = ('time', etik)

        # append the stats for this period type to the others
        if n_f == 0:
            ds_stat_climat = ds_stat_climat_periode
        else:
            ds_stat_climat = xr.concat(
                [ds_stat_climat, ds_stat_climat_periode], 'time')

        # # remove the etiquette variable
        # if 'etiquette' in ds_stat_climat:
        #     ds_stat_climat = ds_stat_climat.drop('etiquette')

    # add lat and lon variables
    if 'time' in da.lon.dims:
        ds_stat_climat.coords['lon'] = da.lon.isel(time=0).drop('time')
        ds_stat_climat.coords['lat'] = da.lat.isel(time=0).drop('time')
    else:
        ds_stat_climat.coords['lon'] = da.lon
        ds_stat_climat.coords['lat'] = da.lat

    # save to the netCDF file
    #
    # make a .bak copy if the output file already exists
    if os.path.exists(fichier_sortie):
        os.rename(fichier_sortie, fichier_sortie + '.bak')
    # prepare the encoding for the different variables
    encode_float = dict(zlib=True, complevel=6, dtype='float32')
    encode_double = dict(zlib=True, complevel=6, dtype='float64')
    encoding = {var: encode_float for var in ds_stat_climat.data_vars}
    encoding.update({var: encode_double for var in ds_stat_climat.coords})
    encoding['climatology_bounds'] = encode_double
    # drop the encoding for 'periode'
    tampon = encoding.pop('periode')

    # save the file
    dj = ds_stat_climat.to_netcdf(path=fichier_sortie,
                                  encoding=encoding,
                                  mode='w',
                                  format='NETCDF4_CLASSIC',
                                  compute=False)
    print()
    print(f"sauvegarde dans le fichier {fichier_sortie}")
    print()

    with ProgressBar():
        dj.compute()

    # close the dataset
    ds.close()
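The save step above builds a delayed write with `to_netcdf(..., compute=False)` and only triggers it inside the `ProgressBar()` context, so the bar tracks the actual write. A minimal self-contained sketch of that pattern with synthetic data and a hypothetical output path:

import numpy as np
import xarray as xr
from dask.diagnostics import ProgressBar

# small chunked dataset standing in for ds_stat_climat
ds_demo = xr.Dataset({"tas": ("time", np.random.rand(120))}).chunk({"time": 12})

delayed_write = ds_demo.to_netcdf("stats.nc",   # hypothetical output file
                                  encoding={"tas": dict(zlib=True, complevel=6)},
                                  compute=False)  # build the task graph only
with ProgressBar():
    delayed_write.compute()  # run the write with a progress bar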
示例#13
0
def cbe(fpaths_lm,
        downsample,
        clustering,
        features,
        normalize_vol=False,
        presample=None,
        cfor=None,
        standardize='default',
        custom_feature_funcs=None,
        bw_method=None,
        dask_graph_path=None,
        processes=None,
        profiling=False,
        suffix_out='default',
        save_metadata=True,
        save_presampled=False,
        save_cfor=False,
        verbose=False,
        legacy=False):
    """Create a feature space from a set of point clouds by cluster-based
    embedding (CBE).

    This includes the following steps:
        1. Loading a set of point clouds
        2. Normalizing point clouds by volume (optional)
        3. Down-sampling of each point cloud individually (optional)
            - Available options are random, kmeans or custom downsampling
        4. Making point clouds invariant to spatial transformation (optional)
            - Also called the "Cell Frame Of Reference" (CFOR)
            - There are currently 2 ways of accomplishing this
                - Transform to pairwise distance space (PD)
                - Transform to PCA space (PCA) [DEPRECATED]
            - It is also possible to pass a custom transform function.
        5. Merging point clouds
        6. Downsampling of merged point clouds (optional but recommended!)
            - Reduces computational cost/scaling of subsequent step
            - Options are density-dep., kmeans, random or custom downsampling
        7. Extracting cluster centers as common reference points
            - Options are kmeans, dbscan and custom clustering
        8. Extracting "cluster features" relative to the reference points
            - Done with dask for efficient chaining of operations
            - Multiple feature options available, see below
        9. Saving the resulting feature space as well as intermediate results

    Cluster features that can be extracted:
        - "kNN-distsManh"  : Manhatten distance in all dimensions of each
                             cluster to the mean point of its k nearest
                             neighbor landmarks.
        - "kNN-distEuclid" : Euclidean distance of each cluster to the mean
                             point of its k nearest neighbor landmarks.
        - "NN-distsManh"   : Manhatten distance in all dimensions of each
                             cluster to the nearest neighboring landmark.
        - "NN-distEuclid"  : Euclidean distance of each cluster to the nearest
                             neighboring landmark.
        - "count-near"     : Number of landmarks near to the cluster, where
                             'near' is the mean distance of the k nearest
                             neighbor landmarks of the cluster.
        - "count-assigned" : Number of landmarks assigned to the cluster during
                             the clustering itself.
        - "kde"            : KDE estimated from cell landmarks sampled for each
                             cluster center.
        - custom features  : See custom_feature_funcs in parameters.

    Feature computations are in part dependent on each other. To make this both
    efficient and readable/elegant, dask is used for chaining the feature
    extraction steps appropriately.

    At the end, features are concatenated into a single array of shape
    (cells, features) and then saved for each input stack separately.

    Parameters
    ----------
    fpaths_lm : single string or list of strings
        A path or list of paths (either local from cwd or global) to npy files
        containing cellular landmarks as generated by
        `katachi.tools.assign_landmarks` or `...find_TFOR`.
    downsample : tuple (algorithm, output_size) or None
        A tuple specifying the algorithm to use for downsampling of the merged
        point cloud prior to cluster extraction. Available algorithms are
        "ddds" (density-dependent downsampling), "kmeans" (perform kmeans and
        use cluster centers as new points) or "random". If "default" is passed,
        "ddds" is used.
        Example: ("ddds", 200000).
        Alternatively, if instead of a string denoting the algorithm a callable
        is passed, that callable is used for downsampling.
        The call signature is
        `all_lms_ds = downsample[0](all_lms, downsample)`
        where all_lms is an array of shape (all_landmarks, dimensions) holding
        all input landmarks merged into one point cloud. Since the `downsample`
        tuple itself is passed, additional arguments can be specified in
        additional elements of that tuple. all_lms_ds must be an array of shape
        (output_size, dimensions).
        If None, no downsampling is performed. This is not recommended for
        inputs of relevant sizes (total landmarks > 20000).
        WARNING: downsampling (especially by ddds) can be very expensive for
        large numbers of cells. In those cases, it is recommended to first run
        a representative subset of the cells and then use the resulting CBE
        clusters to extract features for the entire dataset (using the
        `previous` setting in the `clustering` argument).
    clustering : tuple (algorithm, n_clusters)
        A tuple specifying the algorithm to use for computing the clusters to
        use in cluster-based feature extraction. Available algorithms are
        "kmeans" or "dbscan". If "default" is passed, "kmeans" is used.
        Example: ('kmeans', 10)
        Alternatively, one may pass a tuple `('previous', clustering_object)`,
        where `clustering_object` is a previously fitted clustering instance
        similar to an instantiated and fitted sklearn.cluster.KMeans object. It
        must have the attribute `cluster_centers_`, which is an array of shape
        (clusters, dimensions) and the method `predict`, which given an array
        of shape `(all_landmarks, dimensions)` will return cluster labels for
        each landmark. Clustering objects from previous runs are stored in
        the metadata under the key `"clustobj-"+identifier`.
        Alternatively, if instead of a string denoting the algorithm a callable
        is passed, that callable is used for clustering.
        The call signature is
        `clust_labels, clust_centers = clustering[0](all_lms, clustering)`
        where all_lms is an array of shape (all_landmarks, dimensions) holding
        all input landmarks merged into one point cloud (and downsampled in the
        previous step). Since the `clustering` tuple itself is passed,
        additional arguments can be specified in additional elements of that
        tuple. `clust_labels` must be a 1D integer array assigning each input
        landmark to a corresponding cluster center. `clust_centers` must be an
        array of shape (clusters, dimensions) and contain the coordinates of
        the cluster centers. The first axis must be ordered such that the
        integers in `clust_labels` index it correctly. The number of clusters
        must match n_clusters.
    features : list of strings
        List containing any number of cluster features to be extracted. The
        strings noted in the explanation above are allowed. If custom feature
        extraction functions are passed (see below), their names must also be
        included in this list.
        Example: ["kNN-distEuclid", "count-near"]
    normalize_vol : bool, optional, default False
        If True, the volume of each input point cloud is normalized by dividing
        each landmark vector magnitude by the sum of all magnitudes.
    presample : tuple (algorithm, output_size) or None, optional, default None
        If not None, the algorithm specified is used to downsample each input
        cloud individually to output_size points. Available algorithms are
        "kmeans" (perform kmeans and use cluster centers as new points) or
        "random".
        Example: ('random', 50)
        Alternatively, if instead of a string denoting the algorithm a callable
        is passed, that callable is used for downsampling.
        The call signature is
        ```for cell in range(lms.shape[0]):
               lms_ds[cell,:,:] = presample[0](lms[cell,:,:], presample)```
        where lms is an array of shape (cells, landmarks, dimensions) holding
        the set of input point clouds. Since the `presample` tuple itself is
        passed, additional arguments can be specified in additional elements of
        that tuple. lms_ds must be an array of shape
        (cells, output_size, dimensions).
        If None, no presampling is performed.
    cfor : tuple (algorithm, dimensions) or None, optional, default None
        A tuple specifying the algorithm to use for recasting the landmarks in
        a space that is invariant to spatial transformations. There are two
        options available: "PD" (pairwise distance transform) and "PCA"
        (per-cell PCA and transform).
        For "PD", the total complement of pairwise distances between all points
        is computed and then subsampled to `dimensions` by selecting a
        corresponding number of distance percentiles in a linear range between
        the 10th to the 90th percentile (inclusive).
        For "PCA", the number of dimensions in the resulting space is equal to
        the number of dimensions of the input (should be 3). The `dimensions`
        part of the argument is ignored (but it must still be supplied!).
        If "default" is passed, "PD" is used.
        Example 1: ('PD', 6)
        Example 2: ('default', 6)  # defaults to 'PD'
        Example 3: ('PCA', 3)
        Alternatively, if a callable is passed instead of a string, that
        callable is used for the transformation.
        The call signature is
        ```for cell in range(lms.shape[0]):
               lms_cfor[cell,:,:] = cfor[0](lms[cell,:,:], cfor)```
        where lms is an array of shape (cells, landmarks, dimensions) holding
        the set of input point clouds. Since the `cfor` tuple itself is passed,
        additional arguments can be specified in additional elements of that
        tuple. lms_cfor must be an array of shape
        (cells, output_size, dimensions).
        If None, no transformation is performed; cells are left in the original
        3D space.
    standardize : bool or 'default', optional, default 'default'
        If True, the point cloud dimensions of the merged CFOR point cloud are
        standardised to zero mean and unit variance. This is also propagated
        to the individual clouds used for feature extraction and for saving
        in case the CFOR is being saved.
        If 'default', standardization is performed only if cfor is set to "PD".
        If False, no standardization is performed.
    custom_feature_funcs : list of tuples or None, optional, default None
        List used to specify one or more custom feature extraction functions.
        Each custom function is specified through a tuple in the list that
        is structured as such:
            `(feature_name, extraction_func, parent_names, other_params)`
        where `feature_name` is the name of the feature as it appears in the
        `features` argument, `extraction_func` is a callable, `parent_names`
        is a list of parent feature names (as they appear in the `features`
        argument) used as input to `extraction_func`, and `other_params` is a
        list of other parameters for `extraction_func`.
        The call signature is
        ```dask_graph[custom_func[0]+"_%i" % c] =
               (feature_name, [parent+"_%i" % c for parent in parent_names],
                other_params, lms[c,:,:], clust_centers, clust_labels[c]) ```
        within the dask graph, where `c` is the index of a cell.
        The callable must therefore accept a list of parent features (can be
        an empty list), a list of other parameters (can also be empty), the
        (preprocessed) landmarks of the given cell, the cluster centers and
        the cluster labels of the given cell.
        It must return a 1D array of float values; the feature vector for the
        current cell `c`.
    bw_method : str, scalar, callable or None, optional, default None
        The method used to calculate the estimator bandwidth for the gaussian
        kde when computing the "kde" feature. This can be ‘scott’, ‘silverman’,
        a scalar constant or a callable. If a scalar, this will be used
        directly as `kde.factor`. If a callable, it should take a gaussian_kde
        instance as only parameter and return a scalar. If None (default),
        ‘scott’ is used. This is ignored if "kde" is not in `features`.
        < Modified from `scipy.stats.gaussian_kde` doc string. >
    dask_graph_path : string or None, optional, default None
        If a path (including a file ending matching a known image format, such
        as '.png') is specified as a string, a dask graph image is created that
        summarizes the feature extraction pipeline for the first 3 cells.
        Note: If the resulting graph contains multiple separate graphs, the
        only relevant graph is the one leading into `fspace` as an end result.
    processes : int or None, optional, default None
        Number of processes to use in multiprocessed and dask-controlled
        operations. If None, a number equal to half the available CPUs is used.
        If `1` (one), no multiprocessing is performed and `dask.get` is used
        instead of `dask.threaded.get`.
    profiling : bool, optional, default False
        If True, dask resource profiling is performed and visualized after the
        pipeline run is finished. This may generate a `profile.html` file in
        the working directory [bug in dask].
    suffix_out : 'default' or dict, optional, default 'default'
        If 'default', the output is saved using '_PRES', '_CFOR', '_DS', and
        '_CBE' as suffices for the presampled landmarks (if `presample` is not
        None), for the CFOR-transformed landmarks (if `cfor` is not None), for
        overlaid downsampling (if `downsample` is not None) (note that this is
        not saved explicitly but is part of the suffix for the CBE-embedded
        feature space), and for the CBE-embedded feature space, respectively.
        The suffices are chained as appropriate. If a dict is passed, each of
        these suffices can be specified manually using the keys 'PRES', 'CFOR',
        'DS', 'CBE' and 'META'.
        The suffix specified in 'META' is added to all relevant metadata
        dictionary keys. For any suffixes not specified in the suffix_out dict,
        the 'default' suffix is used.
    save_metadata : bool, optional, default True
        If True, cluster samples, cluster labels and a feature header are saved
        to the metadata of each input stack as appropriate.
    save_presampled : bool, optional, default False
        If True, the result of the presampling step is saved with the suffix
        "PRES" for later use.
    save_cfor : bool, optional, default False
        If True, the result of the cfor step is saved with the suffix "CFOR"
        for later use.
    verbose : bool, optional, default False
        If True, more information is printed.
    legacy : bool, optional, default False
        If True (and standardize is also set to True), the feature extraction
        is not performed in standardized space. Instead, the cluster centroids
        are transformed back to the un-standardized space.
        Triggers a deprecation warning.
    """

    #--------------------------------------------------------------------------

    ### Load data

    if verbose: print "Loading data..."

    # Handle cases of single paths
    if isinstance(fpaths_lm, str):
        fpaths_lm = [fpaths_lm]
    if len(fpaths_lm) == 1:
        warn(
            "fpaths_lm specifies only a single path. Usually, multiple paths" +
            " are specified so that many samples can be overlayed for" +
            " feature extraction!")

    # Import the landmark data
    # Note: The order of fpaths_lm is maintained and an index array is created!
    lms = []
    lms_idx = []
    for idx, fpath_lm in enumerate(fpaths_lm):
        try:
            lms_in = np.load(fpath_lm)
            lms.append(lms_in)
            lms_idx += [idx for i in range(lms_in.shape[0])]
        except:
            print "Attempting to load landmark data from " + str(fpath_lm),
            print "failed with this error:"
            raise
    lms_idx = np.array(lms_idx, dtype=int)
    lms = np.concatenate(lms)
    if verbose: print "Total input data shape:", lms.shape

    # Check if downsampling is specified
    if downsample is None:
        warn("It is highly recommended to use downsampling (unless the data " +
             "set is very small)!")

    # Handle processes being None
    if processes is None:
        processes = cpu_count() // 2

    # Handle standardize being default
    if standardize == 'default':
        standardize = False
        if cfor is not None and cfor[0] == 'PD':
            standardize = True

    # Handle legacy mode
    if legacy:
        warn("Running in LEGACY mode! This is DEPRECATED!", DeprecationWarning)

    #--------------------------------------------------------------------------

    ### Normalize volume [per cell]

    if normalize_vol:
        if verbose: print "Normalizing volumes..."
        lms = vol_normalize(lms, verbose=verbose)

    #--------------------------------------------------------------------------

    ### Individual downsampling (presampling) [per cell]

    if presample is not None:
        if verbose: print "Presampling..."

        # Prep
        lms_ps = np.zeros((lms.shape[0], presample[1], lms.shape[2]))

        # Random subsampling
        if presample[0] == 'random':
            for cell in range(lms.shape[0]):
                lms_ps[cell, :, :] = ds.random_subsample(
                    lms[cell, :, :], presample[1])

        # Kmeans-based downsampling
        elif presample[0] == 'kmeans':
            for cell in range(lms.shape[0]):
                lms_ps[cell, :, :] = ds.kmeans_subsample(
                    lms[cell, :, :], presample[1])

        # Custom downsampling function
        elif callable(presample[0]):
            for cell in range(lms.shape[0]):
                lms_ps[cell, :, :] = presample[0](lms[cell, :, :], presample)

        # Handle other cases
        else:
            raise ValueError("Invalid presampling method: " +
                             str(presample[0]))

        # Assign the downsampled data back
        lms = lms_ps

    #--------------------------------------------------------------------------

    ### Transform to "Cell Frame Of Reference" (CFOR) [per cell]

    if cfor is not None:
        if verbose: print "Transforming to CFOR..."

        # Prep
        lms_cfor = np.zeros((lms.shape[0], lms.shape[1], cfor[1]))

        # Pairwise distance transform
        if cfor[0] == 'PD' or cfor[0] == 'default':
            for cell in range(lms.shape[0]):
                lms_cfor[cell, :, :] = pd_transform(lms[cell, :, :],
                                                    percentiles=cfor[1])

        # PCA transform
        elif cfor[0] == 'PCA':
            for cell in range(lms.shape[0]):
                lms_cfor[cell, :, :] = PCA().fit_transform(lms[cell, :, :])

        ## RBF transform by Nystroem embedding
        ## REMOVED: This does not create matched dimensions and thus cannot be
        ##          used for this purpose.
        #if cfor[0] == 'RBF':
        #    for cell in range(lms.shape[0]):
        #        Ny = kernel_approximation.Nystroem(kernel='rbf',
        #                                           gamma=1/lms.shape[1],
        #                                           n_components=cfor[1],
        #                                           random_state=42)
        #        lms_cfor[cell,:,:] = Ny.fit_transform(lms[cell,:,:])

        # Custom CFOR transform
        elif callable(cfor[0]):
            for cell in range(lms.shape[0]):
                lms_cfor[cell, :, :] = cfor[0](lms[cell, :, :], cfor)

        # Handle other cases
        else:
            raise ValueError("Invalid CFOR method: " + str(cfor[0]))

        # Assign the CFOR data back
        lms = lms_cfor

    #--------------------------------------------------------------------------

    ### Collective downsampling (all cells overlayed) [altogether]
    #   Note: This is done to improve cluster retrieval and to make it more
    #         efficient. It does not affect the feature extraction afterwards.

    # Flatten cells of all samples together
    all_lms = lms.reshape((lms.shape[0] * lms.shape[1], lms.shape[2]))

    # For CFOR-PD: standardize the dimensions
    if standardize and not legacy:

        # Standardize pooled landmarks
        cloud_means = all_lms.mean(axis=0)
        cloud_stds = all_lms.std(axis=0)
        all_lms = (all_lms - cloud_means) / cloud_stds

        # Overwrite unpooled landmarks for feature extraction in standard space
        lms = all_lms.reshape((lms.shape[0], lms.shape[1], lms.shape[2]))

    # Downsampling
    if downsample is not None and clustering[0] != 'previous':
        if verbose: print "Downsampling merged cloud..."

        # Default is density dependent downsampling
        if downsample[0] == 'default' or downsample[0] == 'ddds':
            all_lms_ds = ds.ddds(all_lms,
                                 downsample[1],
                                 presample=downsample[1],
                                 processes=processes)

        # Alternative: kmeans downsampling
        elif downsample[0] == 'kmeans':
            all_lms_ds = ds.kmeans_subsample(all_lms, downsample[1])

        # Alternative: random downsampling
        elif downsample[0] == 'random':
            all_lms_ds = ds.random_subsample(all_lms, downsample[1])

        # Custom downsampling
        elif callable(downsample[0]):
            all_lms_ds = downsample[0](all_lms, downsample)

        # Handle other cases
        else:
            raise ValueError("Invalid downsampling method: " +
                             str(downsample[0]))

    # No downsampling
    else:
        all_lms_ds = all_lms

    # LEGACY: Standardization after downsampling and without overwriting the
    #         unpooled landmarks!
    if legacy and standardize:
        cloud_means = all_lms_ds.mean(axis=0)
        cloud_stds = all_lms_ds.std(axis=0)
        all_lms_ds = (all_lms_ds - cloud_means) / cloud_stds

    #--------------------------------------------------------------------------

    ### Find reference points by clustering [altogether]

    if verbose: print "Clustering to find reference points..."

    # Default: kmeans clustering
    if clustering[0] == 'default' or clustering[0] == 'kmeans':

        # Perform clustering
        my_clust = MiniBatchKMeans(n_clusters=clustering[1], random_state=42)
        my_clust.fit(all_lms_ds)

        # Get labels and centroids
        clust_labels = my_clust.labels_
        clust_centers = my_clust.cluster_centers_

        # Predict labels for whole data set (if downsampled)
        if downsample is not None:
            clust_labels = my_clust.predict(all_lms)

    # To be added: DBSCAN
    elif clustering[0] == 'dbscan':
        raise NotImplementedError("And likely never will be...")

    # Using a given (already fitted) clustering object
    elif clustering[0] == 'previous':
        my_clust = clustering[1]
        clust_centers = my_clust.cluster_centers_
        clust_labels = my_clust.predict(all_lms)

    # Custom alternative
    elif callable(clustering[0]):
        clust_labels, clust_centers = clustering[0](all_lms, clustering)

    # Handle other cases
    else:
        raise ValueError("Invalid clustering method: " + str(clustering[0]))

    # LEGACY: Back-transform of centroids to un-standardized space
    #         In legacy, feature extraction was done on the un-standardized
    #         space, using the back-transformed centroids
    if legacy and standardize:
        clust_centers = clust_centers * cloud_stds + cloud_means

    # Unpool cluster labels
    clust_labels = clust_labels.reshape((lms.shape[0], lms.shape[1]))

    #--------------------------------------------------------------------------

    ### Extract features relative to reference points [per cell]

    if verbose: print "Extracting cluster features..."

    # Init dask graph
    dask_graph = dict()

    # For each cell...
    for c in range(lms.shape[0]):

        # Node to compute kdtree
        dask_graph["kdtree_%i" % c] = (fe.build_kdtree, lms[c, :, :])

        # Nodes for the features
        dask_graph["kNN-distsManh_%i" % c] = (fe.feature_distsManhatten_kNN,
                                              "kdtree_%i" % c, lms[c, :, :],
                                              clust_centers)

        dask_graph["kNN-distEuclid_%i" % c] = (fe.feature_distEuclidean_kNN,
                                               "kNN-distsManh_%i" % c,
                                               lms.shape[2])

        dask_graph["NN-distsManh_%i" % c] = (fe.feature_distsManhatten_NN,
                                             "kdtree_%i" % c, lms[c, :, :],
                                             clust_centers)

        dask_graph["NN-distEuclid_%i" % c] = (fe.feature_distEuclidean_NN,
                                              "NN-distsManh_%i" % c,
                                              lms.shape[2])

        dask_graph["count-near_%i" % c] = (fe.feature_count_near, [
            "kdtree_%i" % c, "kNN-distEuclid_%i" % c
        ], lms[c, :, :], clust_centers)

        dask_graph["count-assigned_%i" % c] = (fe.feature_count_assigned,
                                               clust_centers, clust_labels[c])

        dask_graph["kde_%i" % c] = (fe.feature_kde, lms[c, :, :],
                                    clust_centers, bw_method)

        # Nodes for custom feature extraction functions
        if custom_feature_funcs is not None:
            for custom_func in custom_feature_funcs:
                custom_parents = [
                    parent + "_%i" % c for parent in custom_func[2]
                ]
                dask_graph[custom_func[0] +
                           "_%i" % c] = (custom_func[1], custom_parents,
                                         custom_func[3], lms[c, :, :],
                                         clust_centers, clust_labels[c])

        # Node to collect requested features for a cell
        dask_graph["fvector_%i" % c] = (fe.assemble_cell,
                                        [f + "_%i" % c
                                         for f in features], features)

        # Render example graph for first 3 cells
        if c == 2 and dask_graph_path is not None:
            from dask.dot import dot_graph
            dask_graph["fspace"] = (fe.assemble_fspace,
                                    ["fvector_%i" % c for c in range(3)])
            dot_graph(dask_graph, filename=dask_graph_path)

    # Final node to put per-cell features into a feature space
    dask_graph["fspace"] = (fe.assemble_fspace,
                            ["fvector_%i" % c for c in range(lms.shape[0])])

    # Run without multiprocessing
    if processes == 1:
        with ProgressBar(dt=1):
            fspace, fheader = dask.get(dask_graph, 'fspace')

    # Run with multiprocessing
    else:

        # Set number of threads
        dask.set_options(pool=ThreadPool(processes))

        # Run the pipeline (no profiling)
        if not profiling:
            with ProgressBar(dt=1):
                fspace, fheader = dask.threaded.get(dask_graph, 'fspace')

        # Run the pipeline (with resource profiling)
        if profiling:
            with ProgressBar(dt=1):
                with Profiler() as prof, ResourceProfiler(dt=0.1) as rprof:
                    fspace, fheader = dask.threaded.get(dask_graph, 'fspace')
                visualize([prof, rprof], save=False)

    #--------------------------------------------------------------------------

    ### Save [per stack], report and return

    if verbose: print "Saving result..."

    # For each stack...
    for sample_idx, sample_fpath in enumerate(fpaths_lm):

        # Prepare suffix
        suffix = ""

        # Save individually downsampled landmark distributions if desired
        if presample is not None and save_presampled:
            if suffix_out == 'default' or 'PRES' not in suffix_out.keys():
                suffix = suffix + "_PRES"
            else:
                suffix = suffix + suffix_out['PRES']
            np.save(sample_fpath[:-4] + suffix,
                    lms_ps[lms_idx == sample_idx, :, :])

        # Save CFOR if desired
        if cfor is not None and save_cfor:
            if suffix_out == 'default' or 'CFOR' not in suffix_out.keys():
                suffix = suffix + "_CFOR"
            else:
                suffix = suffix + suffix_out['CFOR']
            np.save(sample_fpath[:-4] + suffix,
                    lms[lms_idx == sample_idx, :, :])

        # Include downsampling in suffix
        if downsample is not None:
            if suffix_out == 'default' or 'DS' not in suffix_out.keys():
                suffix = suffix + '_DS'
            else:
                suffix = suffix + suffix_out['DS']

        # Save shape space
        if suffix_out == 'default' or 'CBE' not in suffix_out.keys():
            suffix = suffix + "_CBE"
        else:
            suffix = suffix + suffix_out['CBE']
        np.save(sample_fpath[:-4] + suffix, fspace[lms_idx == sample_idx, :])

        # Save new metadata
        if save_metadata:

            # Construct metadata path
            dirpath, fname = os.path.split(sample_fpath)
            fpath_meta = os.path.join(dirpath,
                                      fname[:10] + "_stack_metadata.pkl")

            # Open metadata
            with open(fpath_meta, 'rb') as metafile:
                meta_dict = pickle.load(metafile)

            # Prepare metadata suffix
            if suffix_out == 'default' or 'META' not in suffix_out.keys():
                if suffix[0] == '_':
                    m_suffix = suffix[1:]
                else:
                    m_suffix = suffix
            else:
                if suffix[0] == '_':
                    m_suffix = suffix[1:] + suffix_out['META']
                else:
                    m_suffix = suffix + suffix_out['META']

            # Slightly awkward addition of TFOR tag
            if 'TFOR' in fpaths_lm[0]:
                m_suffix = 'TFOR_' + m_suffix

            # Add new metadata
            meta_dict["clustobj-" + m_suffix] = my_clust
            meta_dict["clusters-" + m_suffix] = clust_centers
            meta_dict["labels-" +
                      m_suffix] = clust_labels[lms_idx == sample_idx]
            meta_dict["features-" + m_suffix] = fheader

            # Write metadata
            with open(fpath_meta, 'wb') as metafile:
                pickle.dump(meta_dict, metafile, pickle.HIGHEST_PROTOCOL)

    # Report and return
    if verbose: print "Processing complete!"
    return
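As a usage illustration of the call signatures documented in the docstring above, here is a hedged sketch of invoking `cbe` with a custom presampling callable; the input paths and the callable are hypothetical, and the remaining arguments simply follow the examples given in the parameter descriptions:

def head_presample(cell_lms, presample):
    # custom presampling callable following the documented signature;
    # presample is the whole tuple, so presample[1] is the output size
    # (assumes every cell has at least presample[1] landmarks)
    return cell_lms[:presample[1], :]

cbe(["sampleA_landmarks.npy", "sampleB_landmarks.npy"],   # hypothetical input files
    downsample=("ddds", 200000),
    clustering=("kmeans", 10),
    features=["kNN-distEuclid", "count-near"],
    presample=(head_presample, 50),
    cfor=("PD", 6),
    verbose=True)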
示例#14
0
文件: ms_helper.py 项目: mfkiwl/disko
def read_ms(ms, num_vis, res_arcmin, chunks=50000, channel=0):
    '''
        Use dask-ms to load the necessary data to create a telescope operator
        (will use uvw positions, and antenna positions)
        
        -- res_arcmin: Used to calculate the maximum baselines to consider.
                       We want two pixels per smallest fringe
                       pix_res > fringe / 2
                       
                       u sin(theta) = n (for the nth fringe)
                       at small angles: theta = 1/u, so u_max = 1 / theta

                       d sin(theta) = lambda / 2
                       d / lambda = 1 / (2 sin(theta))
                       u_max = d / lambda = 1 / (2 sin(theta))
    '''
    with scheduler_context():
        # Create a dataset representing the entire antenna table
        ant_table = '::'.join((ms, 'ANTENNA'))

        for ant_ds in xds_from_table(ant_table):
            #print(ant_ds)
            #print(dask.compute(ant_ds.NAME.data,
            #ant_ds.POSITION.data,
            #ant_ds.DISH_DIAMETER.data))
            ant_p = np.array(ant_ds.POSITION.data)
        logger.info("Antenna Positions {}".format(ant_p.shape))

        # Create a dataset representing the field
        field_table = '::'.join((ms, 'FIELD'))
        for field_ds in xds_from_table(field_table):
            #print(ant_ds)
            #print(dask.compute(ant_ds.NAME.data,
            #ant_ds.POSITION.data,
            #ant_ds.DISH_DIAMETER.data))
            phase_dir = np.array(field_ds.PHASE_DIR.data)[0].flatten()
        logger.info("Phase Dir {}".format(np.degrees(phase_dir)))

        # Create datasets representing each row of the spw table
        spw_table = '::'.join((ms, 'SPECTRAL_WINDOW'))

        for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
            #print(spw_ds)
            #print(spw_ds.NUM_CHAN.values)
            logger.info("CHAN_FREQ.values: {}".format(
                spw_ds.CHAN_FREQ.values.shape))
            frequencies = dask.compute(spw_ds.CHAN_FREQ.values)[0].flatten()
            frequency = frequencies[channel]
            logger.info("Frequencies = {}".format(frequencies))
            logger.info("Frequency = {}".format(frequency))
            logger.info("NUM_CHAN = %f" % np.array(spw_ds.NUM_CHAN.values)[0])

        # Create datasets from a partitioning of the MS
        datasets = list(xds_from_ms(ms, chunks={'row': chunks}))

        pol = 0

        for ds in datasets:
            logger.info("DATA shape: {}".format(ds.DATA.data.shape))
            logger.info("UVW shape: {}".format(ds.UVW.data.shape))

            uvw = np.array(ds.UVW.data)  # UVW is stored in meters!
            ant1 = np.array(ds.ANTENNA1.data)
            ant2 = np.array(ds.ANTENNA2.data)
            flags = np.array(ds.FLAG.data)
            cv_vis = np.array(ds.DATA.data)[:, channel, pol]
            epoch_seconds = np.array(ds.TIME.data)[0]

            # Try to write the STATE_ID column back
            write = xds_to_table(ds, ms, 'STATE_ID')
            with ProgressBar(), Profiler() as prof:
                write.compute()

            # Profile
            #prof.visualize(file_path="chunked.html")

        ### NOW REMOVE DATA THAT DOESN'T FIT THE IMAGE RESOLUTION

        u_max = get_resolution_max_baseline(res_arcmin, frequency)

        logger.info("Resolution Max UVW: {:g}".format(u_max))
        logger.info("Flags: {}".format(flags.shape))

        # Now report the recommended resolution from the data.
        # 1.0 / 2*np.sin(theta) = limit_u
        limit_uvw = np.max(np.abs(uvw), 0)
        res_limit = get_baseline_resolution(limit_uvw[0], frequency)
        logger.info("Nyquist resolution: {:g} arcmin".format(
            np.degrees(res_limit) * 60.0))

        #maxuvw = np.max(np.abs(uvw), 1)
        #logger.info(np.random.choice(maxuvw, 100))

        if False:
            good_data = np.array(np.where(flags[:, channel,
                                                pol] == 0)).T.reshape((-1, ))
        else:
            good_data = np.array(
                np.where((flags[:, channel, pol] == 0)
                         & (np.max(np.abs(uvw), 1) < u_max))).T.reshape((-1, ))
        logger.info("Good Data {}".format(good_data.shape))

        logger.info("Maximum UVW: {}".format(limit_uvw))
        logger.info("Minimum UVW: {}".format(np.min(np.abs(uvw), 0)))

        n_ant = len(ant_p)

        good_vis = cv_vis[good_data]

        n_max = len(good_vis)

        indices = np.random.choice(good_data, min(num_vis, n_max))

        hdr = {
            'CTYPE1': ('RA---SIN', "Right ascension angle cosine"),
            'CRVAL1': np.degrees(phase_dir)[0],
            'CUNIT1': 'deg     ',
            'CTYPE2': ('DEC--SIN', "Declination angle cosine "),
            'CRVAL2': np.degrees(phase_dir)[1],
            'CUNIT2': 'deg     ',
            'CTYPE3': 'FREQ    ',  # Central frequency
            'CRPIX3': 1.,
            'CRVAL3': "{}".format(frequency),
            'CDELT3': 10026896.158854,
            'CUNIT3': 'Hz      ',
            'EQUINOX': '2000.',
            'DATE-OBS': "{}".format(epoch_seconds),
            'BTYPE': 'Intensity'
        }

        #from astropy.wcs.utils import celestial_frame_to_wcs
        #from astropy.coordinates import FK5
        #frame = FK5(equinox='J2010')
        #wcs = celestial_frame_to_wcs(frame)
        #wcs.to_header()

        u_arr = uvw[indices, 0]
        v_arr = uvw[indices, 1]
        w_arr = uvw[indices, 2]

        cv_vis = cv_vis[indices]

        # Convert from reduced Julian Date to timestamp.
        timestamp = datetime.datetime(
            1858, 11, 17, 0, 0, 0,
            tzinfo=datetime.timezone.utc) + datetime.timedelta(
                seconds=epoch_seconds)

        return u_arr, v_arr, w_arr, frequency, cv_vis, hdr, timestamp
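The baseline cut above relies on the relation sketched in the docstring: requiring two pixels per smallest fringe gives u_max ≈ 1 / (2 sin(theta)) in wavelengths. A small illustrative helper under that assumption (this is not the project's `get_resolution_max_baseline` implementation, just the docstring relation written out):

import numpy as np

def approx_max_baseline_wavelengths(res_arcmin):
    # pixel resolution in arcmin -> radians, then u_max = 1 / (2 sin(theta))
    theta = np.radians(res_arcmin / 60.0)
    return 1.0 / (2.0 * np.sin(theta))

print(approx_max_baseline_wavelengths(3.0))   # ~573 wavelengths for a 3 arcmin pixel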
示例#15
0
def calc(datafiles,
         lat0,
         lon0,
         hs="phs.|hs.",
         tp="ptp.|tp.",
         dp="pdir.|th.",
         si=20,
         groupers=None,
         nblocks=1):
    """Calculate ESTELA dataset for a target point.

    Args:
        datafiles (str or sequence of str): Regular expression or list of data files.
        lat0 (float): Latitude of target point.
        lon0 (float): Longitude of target point.
        hs (str or sequence of str): regex/list of hs field names in datafiles.
        tp (str or sequence of str): regex/list of tp field names in datafiles.
        dp (str or sequence of str): regex/list of dp field names in datafiles.
        si (str or sequence of str or float): Value or regex/list of directional spread field names.
        groupers (sequence of str, optional): Values used to group the results.
        nblocks (int, optional): Number of blocks. More blocks need less memory but calculations are slower.

    Returns:
        xarray.Dataset: ESTELA dataset with F and traveltime fields.
    """
    if isinstance(datafiles, str):
        flist = sorted(glob(datafiles))
    else:
        flist = sorted(datafiles)
    print(
        f"{datetime.datetime.utcnow():%Y%m%d %H:%M:%S} Processing {len(flist)} files"
    )
    groupers = get_groupers(groupers)

    lon0 %= 360.0
    lat0_arr = xr.DataArray(dims="site", data=np.array(lat0).flatten())
    lon0_arr = xr.DataArray(dims="site", data=np.array(lon0).flatten())
    sites = xr.Dataset(dict(lat0=lat0_arr, lon0=lon0_arr))
    # TODO calculate several sites at the same time. Problematic memory usage but much faster (if data reading is slow)

    dsf = xr.open_mfdataset(flist[0])
    # geographical constants and initialization
    dists, bearings = dist_and_bearing(lat0, dsf.latitude, lon0, dsf.longitude)
    vland = geographic_mask(lat0, lon0, dists, bearings)
    dist_m = dists * 6371000 * D2R
    va = 1.4 * 10**-5
    rowroa = 1 / 0.0013
    # sigma = 2 * np.pi / ds.tp  # Lemax = (rowroa * 9.81 ** 2) / (4 * sigma ** 3 * (2 * va * sigma) ** 0.5)
    k_dissipation = (-dist_m / (rowroa * 9.81**2) * 4 * (2 * va)**0.5 *
                     (2 * np.pi)**3.5
                     )  # coef_dissipation = np.exp(-dist_m / Lemax)
    th1_sin = np.sin(0.5 * bearings * D2R)
    th1_cos = np.cos(0.5 * bearings * D2R)

    # S and Stp calculations
    si_calculations = True
    grouped_results = dict()
    for f in flist:
        dsf = xr.open_mfdataset(f).chunk("auto")

        if nblocks > 1:
            dsf_blocks = [g for _, g in dsf.groupby_bins("time", nblocks)]
        else:
            dsf_blocks = [dsf]
        print(
            f"{datetime.datetime.utcnow():%Y%m%d %H:%M:%S} Processing {f} (nblocks={nblocks})"
        )

        spec_info = dict(hs=hs, tp=tp, dp=dp, si=si)
        for k, value in spec_info.items():
            if isinstance(value, str):  # expand regular expressions
                spec_info[k] = sorted(v for v in dsf.variables
                                      if re.fullmatch(value, v))
        npart = len(spec_info["hs"])
        num_si = isinstance(spec_info["si"], (int, float))
        print(spec_info)

        for dsi in dsf_blocks:
            block_results = xr.Dataset()
            for ipart in range(npart):
                hs_data = dsi[spec_info["hs"][ipart]]
                tp_data = dsi[spec_info["tp"][ipart]]
                dp_data = dsi[spec_info["dp"][ipart]]

                coef_dissipation = np.exp(k_dissipation * (tp_data**-3.5))

                if si_calculations:
                    if num_si:  # don't repeat calculations
                        si_data = spec_info["si"]
                        si_calculations = False
                    else:
                        si_data = dsi[spec_info["si"][ipart]].clip(15., 45.)
                        # TODO find better solution to avoid invalid A2 values

                    s = (2 / (si_data * np.pi / 180)**2) - 1
                    A2 = special.gamma(s + 1) / (special.gamma(s + 0.5) * 2 *
                                                 np.pi**0.5)
                    # TODO find faster spread approach (use normal distribution or table?)
                    coef_spread = A2 * np.pi / 180  # deg
                    # TODO review coef_spread units and compare with wavespectra

                th2 = 0.5 * dp_data * D2R
                coef_direction = abs(
                    np.cos(th2) * th1_cos + np.sin(th2) * th1_sin)**(2.0 * s)

                Spart_th = hs_data**2 / 16 * coef_dissipation * coef_direction * coef_spread
                block_results["S_th"] = block_results.get("S_th",
                                                          0) + (Spart_th)
                block_results["Stp_th"] = block_results.get(
                    "Stp_th", 0) + (tp_data * Spart_th)

            with ProgressBar():
                block_results.load()

            for grouper in groupers:
                if grouper == "ALL":
                    grouped_results["ALL"] = grouped_results.get(
                        "ALL", 0) + block_results.sum("time").assign(
                            ntime=len(dsi.time))
                else:
                    for k, v in block_results.groupby(grouper):
                        kstr = f"m{k:02g}" if grouper == "time.month" else str(
                            k)
                        grouped_results[kstr] = grouped_results.get(
                            kstr, 0) + v.sum("time").assign(ntime=len(v.time))

    # Saving estelas
    time = xr.Variable(data=sorted(grouped_results), dims="time")
    estelas_aux = xr.concat([grouped_results[k] for k in time.values],
                            dim=time)
    # TODO Te instead of Tp.  tp_te_ratio = 1.1 ?
    Fdeg = (1.025 * 9.81 * estelas_aux["Stp_th"] / estelas_aux["ntime"] *
            9.81 / 4 / np.pi)
    cg_mps = (estelas_aux["Stp_th"] / estelas_aux["S_th"]) * 9.81 / 4 / np.pi
    estelas_dict = {
        "F": 360 * Fdeg,
        "traveltime": (3600 * 24 * cg_mps / dist_m)**-1
    }  # dimensions order tyx
    estelas = xr.Dataset(estelas_dict).where(vland, np.nan).merge(sites)
    estelas.F.attrs["units"] = "$\\frac{kW}{m\\circ}$"
    estelas.traveltime.attrs["units"] = "days"
    estelas.attrs["start_time"] = str(
        xr.open_mfdataset(flist[0]).time[0].values)
    estelas.attrs["end_time"] = str(
        xr.open_mfdataset(flist[-1]).time[-1].values)
    return estelas
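The `k_dissipation` factor above folds the `Lemax` expression from the inline comment into a single constant so that `coef_dissipation = np.exp(k_dissipation * tp**-3.5)`. A quick numerical check of that algebra, using sigma = 2*pi/tp and arbitrary illustrative values:

import numpy as np

dist_m, tp = 1.0e6, 12.0                  # illustrative distance (m) and peak period (s)
va, rowroa, g = 1.4e-5, 1 / 0.0013, 9.81  # same constants as used in calc()

sigma = 2 * np.pi / tp
Lemax = (rowroa * g**2) / (4 * sigma**3 * (2 * va * sigma)**0.5)
k_dissipation = -dist_m / (rowroa * g**2) * 4 * (2 * va)**0.5 * (2 * np.pi)**3.5

# both forms of the dissipation coefficient agree
assert np.isclose(np.exp(k_dissipation * tp**-3.5), np.exp(-dist_m / Lemax))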
示例#16
0
def plot_lateral_flux_site_analysis(ds, lat, lon, date_offset, n_dates):
    logging.info(
        f"Plotting lateral flux site analysis at ({lat}°N, {lon}°E) for {n_dates} dates..."
    )

    time = ds.time.values
    time_slice = slice(date_offset, date_offset + n_dates + 1)
    simulation_time = ds.time.isel(time=time_slice).values

    # Compute column-integrated fluxes and flux differences

    uT = ds.ADVx_TH
    vT = ds.ADVy_TH
    uS = ds.ADVx_SLT
    vS = ds.ADVy_SLT

    with ProgressBar():
        Σdz_uT = uT.integrate(coord="Z").sel(XG=lon, YC=lat,
                                             method="nearest").values
        Σdz_vT = vT.integrate(coord="Z").sel(XC=lon, YG=lat,
                                             method="nearest").values
        Σdz_uS = uS.integrate(coord="Z").sel(XG=lon, YC=lat,
                                             method="nearest").values
        Σdz_vS = vS.integrate(coord="Z").sel(XC=lon, YG=lat,
                                             method="nearest").values

        ΔΣdz_uT = uT.integrate(coord="Z").diff("XG").sel(
            XG=lon, YC=lat, method="nearest").values
        ΔΣdz_vT = vT.integrate(coord="Z").diff("YG").sel(
            XC=lon, YG=lat, method="nearest").values
        ΔΣdz_uS = uS.integrate(coord="Z").diff("XG").sel(
            XG=lon, YC=lat, method="nearest").values
        ΔΣdz_vS = vS.integrate(coord="Z").diff("YG").sel(
            XC=lon, YG=lat, method="nearest").values

    # Plot column-integrated fluxes time series

    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(16, 12))

    fig.suptitle(
        f"LESbrary.jl SOSE site analysis: lateral fluxes at ({lat}°N, {lon}°E)"
    )

    ax_T = axes[0]
    ax_T.plot(time, Σdz_uT, label=r"$\int uT \; dz$")
    ax_T.plot(time, Σdz_vT, label=r"$\int vT \; dz$")
    ax_T.plot(time, ΔΣdz_uT, label=r"$\Delta \int uT \; dz$")
    ax_T.plot(time, ΔΣdz_vT, label=r"$\Delta \int vT \; dz$")
    ax_T.axvspan(simulation_time[0],
                 simulation_time[-1],
                 color='gold',
                 alpha=0.5)
    ax_T.legend(frameon=False)
    ax_T.set_ylabel(r"$\degree C \cdot m^4/s$")
    ax_T.set_xlim([time[0], time[-1]])
    ax_T.set_xticklabels([])

    ax_S = axes[1]
    ax_S.plot(time, Σdz_uS, label=r"$\int uS \; dz$")
    ax_S.plot(time, Σdz_vS, label=r"$\int vS \; dz$")
    ax_S.plot(time, ΔΣdz_uS, label=r"$\Delta \int uS \; dz$")
    ax_S.plot(time, ΔΣdz_vS, label=r"$\Delta \int vS \; dz$")
    ax_S.axvspan(simulation_time[0],
                 simulation_time[-1],
                 color='gold',
                 alpha=0.5)
    ax_S.legend(frameon=False)
    ax_S.set_ylabel(r"$\mathrm{psu} \cdot m^4/s$")
    ax_S.set_xlim([time[0], time[-1]])

    start_date_str = numpy_datetime_to_date_str(simulation_time[0])
    end_date_str = numpy_datetime_to_date_str(simulation_time[-1])
    filename = f"lesbrary_site_analysis_lateral_fluxes_latitude{lat}_longitude{lon}_{start_date_str}_to_{end_date_str}.png"
    logging.info(f"Saving {filename}...")
    plt.savefig(filename, dpi=300)
    plt.close(fig)
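The flux section above chains lazy xarray operations (`integrate`, `diff`, `sel`) and only forces computation when `.values` is read inside the `ProgressBar()` context. A minimal synthetic sketch of that pattern, with dimension names chosen to mirror the grid used above:

import numpy as np
import xarray as xr
from dask.diagnostics import ProgressBar

flux = xr.DataArray(np.random.rand(4, 3, 5, 6),
                    dims=("time", "Z", "YC", "XG"),
                    coords={"Z": [-5.0, -15.0, -30.0],
                            "YC": np.linspace(-60, -50, 5),
                            "XG": np.linspace(0, 5, 6)}).chunk({"time": 1})

with ProgressBar():
    # column-integrate over depth, pick the grid point nearest a site, then compute
    col_flux = flux.integrate(coord="Z").sel(XG=2.0, YC=-55.0, method="nearest").values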
示例#17
0
    def on_pbAnomalies_click(self):
        """
        Computes the climatology and anomalies for the selected year
        and shows the corresponding plot
        """
        # Wait cursor
        QtWidgets.QApplication.setOverrideCursor(Qt.WaitCursor)

        first_run = False

        # The climatology method will create two datasets:
        #   ts.climatology_mean
        #   ts.climatology_std
        if self.ts.climatology_mean is None and \
            self.ts.climatology_std is None:

            # Compute climatology
            self.ts.climatology()

            with ProgressBar():
                self.ts.climatology_mean = self.ts.climatology_mean.compute()
                self.ts.climatology_std = self.ts.climatology_std.compute()

            self.climatology_year = self.years.currentText()
            first_run = True

        if self.climatology_year is None or \
            self.climatology_year != self.years.currentText() or \
            first_run is True:

            if self.ts.climatology_mean.shape[0] != self.single_year_ds.shape[0]:
                # Standard cursor
                QtWidgets.QApplication.restoreOverrideCursor()

                message_text = (f'Year {self.years.currentText()} does not '
                                f'have same number of observations as the '
                                f'climatology. Anomalies cannot be computed.')
                self.message_box(message_text)
                return None

            # Anomalies
            anomalies = (self.single_year_ds - self.ts.climatology_mean.data) \
                            / self.ts.climatology_std.data

            with ProgressBar():
                self.anomalies = anomalies.compute()

            first_run = False

        self.__plotAnomalies()

        # Output file name
        #smoothing_method = self.smoothing_methods.selectedItems()[0].text()
        #_fname, _ext = os.path.splitext(self.fname)
        #output_fname = f'{_fname}.{smoothing_method}.tif'

        #smoother = Smoothing(fname=self.fname, output_fname=output_fname,
        #        smoothing_methods=[smoothing_method])

        #smoother.smooth()

        # Standard cursor
        QtWidgets.QApplication.restoreOverrideCursor()
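
The method above follows a common pattern: build the climatology lazily, then materialize it (and the derived anomalies) under a dask ProgressBar. A minimal, self-contained sketch of just that mechanic, with an invented dask-backed time series and a day-of-year climatology standing in for ts.climatology():

import numpy as np
import pandas as pd
import xarray as xr
from dask.diagnostics import ProgressBar

# Two non-leap years of made-up daily data, chunked so the array is dask-backed.
time = pd.date_range("2001-01-01", periods=730, freq="D")
ts = xr.DataArray(np.random.rand(730), coords={"time": time}, dims="time").chunk({"time": 90})

clim_mean = ts.groupby("time.dayofyear").mean()   # lazy
clim_std = ts.groupby("time.dayofyear").std()     # lazy

with ProgressBar():                               # bar is printed while dask evaluates
    clim_mean = clim_mean.compute()
    clim_std = clim_std.compute()

# Standardized anomalies relative to the climatology, computed under a second bar.
deviation = ts.groupby("time.dayofyear") - clim_mean
anomalies = deviation.groupby("time.dayofyear") / clim_std
with ProgressBar():
    anomalies = anomalies.compute()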
Example #18
def main(argv=sys.argv[1:]):
    global LOG

    import argparse

    from dask.diagnostics import ProgressBar
    from satpy import Scene
    from satpy.writers import compute_writer_results

    from polar2grid.core.script_utils import (
        create_exc_handler,
        rename_log_file,
        setup_logging,
    )

    add_polar2grid_config_paths()
    USE_POLAR2GRID_DEFAULTS = _get_p2g_defaults_env_var()
    BINARY_NAME = "polar2grid" if USE_POLAR2GRID_DEFAULTS else "geo2grid"

    prog = os.getenv("PROG_NAME", sys.argv[0])
    # "usage: " will be printed at the top of this:
    usage = """
    %(prog)s -h
see available products:
    %(prog)s -r <reader> -w <writer> --list-products -f file1 [file2 ...]
basic processing:
    %(prog)s -r <reader> -w <writer> [options] -f file1 [file2 ...]
basic processing with limited products:
    %(prog)s -r <reader> -w <writer> [options] -p prod1 prod2 -f file1 [file2 ...]
"""
    parser = argparse.ArgumentParser(
        prog=prog,
        usage=usage,
        fromfile_prefix_chars="@",
        description="Load, composite, resample, and save datasets.",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        dest="verbosity",
        action="count",
        default=0,
        help="each occurrence increases verbosity 1 level through " "ERROR-WARNING-INFO-DEBUG (default INFO)",
    )
    parser.add_argument("-l", "--log", dest="log_fn", default=None, help="specify the log filename")
    parser.add_argument(
        "--progress",
        action="store_true",
        help="show processing progress bar (not recommended for logged output)",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=os.getenv("DASK_NUM_WORKERS", 4),
        help="specify number of worker threads to use (default: 4)",
    )
    parser.add_argument(
        "--extra-config-path",
        action="append",
        help="Specify the base directory of additional Satpy configuration "
        "files. For example, to use custom enhancement YAML file named "
        "'generic.yaml' place it in a directory called 'enhancements' "
        "like '/path/to/my_configs/enhancements/generic.yaml' and then "
        "set this flag to '/path/to/my_configs'.",
    )
    parser.add_argument(
        "--match-resolution",
        dest="preserve_resolution",
        action="store_false",
        help="When using the 'native' resampler for composites, don't save data "
        "at its native resolution, use the resolution used to create the "
        "composite.",
    )
    parser.add_argument(
        "--list-products",
        dest="list_products",
        action="store_true",
        help="List available {} products and exit".format(BINARY_NAME),
    )
    parser.add_argument(
        "--list-products-all",
        dest="list_products_all",
        action="store_true",
        help="List available {} products and custom/Satpy products and exit".format(BINARY_NAME),
    )
    reader_group = add_scene_argument_groups(parser, is_polar2grid=USE_POLAR2GRID_DEFAULTS)[0]
    resampling_group = add_resample_argument_groups(parser, is_polar2grid=USE_POLAR2GRID_DEFAULTS)[0]
    writer_group = add_writer_argument_groups(parser)[0]
    argv_without_help = [x for x in argv if x not in ["-h", "--help"]]

    _retitle_optional_arguments(parser)
    args, remaining_args = parser.parse_known_args(argv_without_help)
    os.environ["DASK_NUM_WORKERS"] = str(args.num_workers)

    # get the logger if we know the readers and writers that will be used
    if args.readers is not None and args.writers is not None:
        glue_name = args.readers[0] + "_" + "-".join(args.writers or [])
        LOG = logging.getLogger(glue_name)
    reader_subgroups = _add_component_parser_args(parser, "readers", args.readers or [])
    writer_subgroups = _add_component_parser_args(parser, "writers", args.writers or [])
    args = parser.parse_args(argv)

    if args.readers is None:
        parser.print_usage()
        parser.exit(
            1,
            "\nERROR: Reader must be provided (-r flag).\n"
            "Supported readers:\n\t{}\n".format("\n\t".join(_supported_readers(USE_POLAR2GRID_DEFAULTS))),
        )
    elif len(args.readers) > 1:
        parser.print_usage()
        parser.exit(
            1,
            "\nMultiple readers is not currently supported. Got:\n\t" "{}\n".format("\n\t".join(args.readers)),
        )
        return -1
    if args.writers is None:
        parser.print_usage()
        parser.exit(
            1,
            "\nERROR: Writer must be provided (-w flag) with one or more writer.\n"
            "Supported writers:\n\t{}\n".format("\n\t".join(_supported_writers(USE_POLAR2GRID_DEFAULTS))),
        )

    reader_args = _args_to_dict(args, reader_group._group_actions)
    reader_names = reader_args.pop("readers")
    scene_creation, load_args = _get_scene_init_load_args(args, reader_args, reader_names, reader_subgroups)
    resample_args = _args_to_dict(args, resampling_group._group_actions)
    writer_args = _args_to_dict(args, writer_group._group_actions)
    writer_specific_args = _parse_writer_args(
        writer_args["writers"], writer_subgroups, reader_names, USE_POLAR2GRID_DEFAULTS, args
    )
    writer_args.update(writer_specific_args)

    if not args.filenames:
        parser.print_usage()
        parser.exit(1, "\nERROR: No data files provided (-f flag)\n")

    # Prepare logging
    rename_log = False
    if args.log_fn is None:
        rename_log = True
        args.log_fn = glue_name + "_fail.log"
    levels = [logging.ERROR, logging.WARN, logging.INFO, logging.DEBUG]
    setup_logging(console_level=levels[min(3, args.verbosity)], log_filename=args.log_fn)
    logging.getLogger("rasterio").setLevel(levels[min(2, args.verbosity)])
    logging.getLogger("fsspec").setLevel(levels[min(2, args.verbosity)])
    logging.getLogger("s3fs").setLevel(levels[min(2, args.verbosity)])
    logging.getLogger("aiobotocore").setLevel(levels[min(2, args.verbosity)])
    logging.getLogger("botocore").setLevel(levels[min(2, args.verbosity)])
    sys.excepthook = create_exc_handler(LOG.name)
    if levels[min(3, args.verbosity)] > logging.DEBUG:
        import warnings

        warnings.filterwarnings("ignore")
    LOG.debug("Starting script with arguments: %s", " ".join(sys.argv))
    if args.extra_config_path:
        add_extra_config_paths(args.extra_config_path)
    LOG.debug(f"Satpy config path is: {satpy.config.get('config_path')}")

    # Set up dask and the number of workers
    if args.num_workers:
        dask.config.set(num_workers=args.num_workers)

    # Create a Scene, analyze the provided files
    LOG.info("Sorting and reading input files...")
    try:
        scn = Scene(**scene_creation)
    except ValueError as e:
        LOG.error("{} | Enable debug message (-vvv) or see log file for details.".format(str(e)))
        LOG.debug("Further error information: ", exc_info=True)
        return -1
    except OSError:
        LOG.error("Could not open files. Enable debug message (-vvv) or see log file for details.")
        LOG.debug("Further error information: ", exc_info=True)
        return -1

    # Rename the log file
    if rename_log:
        rename_log_file(glue_name + scn.attrs["start_time"].strftime("_%Y%m%d_%H%M%S.log"))

    # Load the actual data arrays and metadata (lazy loaded as dask arrays)
    LOG.info("Loading product metadata from files...")
    reader_info = ReaderProxyBase.from_reader_name(scene_creation["reader"], scn, load_args["products"])
    if args.list_products or args.list_products_all:
        _print_list_products(reader_info, p2g_only=not args.list_products_all)
        return 0

    load_args["products"] = reader_info.get_satpy_products_to_load()
    if not load_args["products"]:
        return -1
    products = load_args.pop("products")
    scn.load(products, **load_args)

    ll_bbox = resample_args.pop("ll_bbox")
    if ll_bbox:
        scn = scn.crop(ll_bbox=ll_bbox)

    scn = filter_scene(
        scn,
        reader_names,
        sza_threshold=reader_args["sza_threshold"],
        day_fraction=reader_args["filter_day_products"],
        night_fraction=reader_args["filter_night_products"],
    )
    if scn is None:
        LOG.info("No remaining products after filtering.")
        return 0

    to_save = []
    areas_to_resample = resample_args.pop("grids")
    if "ewa_persist" in resample_args:
        resample_args["persist"] = resample_args.pop("ewa_persist")
    scenes_to_save = resample_scene(
        scn,
        areas_to_resample,
        preserve_resolution=args.preserve_resolution,
        is_polar2grid=USE_POLAR2GRID_DEFAULTS,
        **resample_args,
    )
    for scene_to_save, products_to_save in scenes_to_save:
        overwrite_platform_name_with_aliases(scene_to_save)
        reader_info.apply_p2g_name_to_scene(scene_to_save)
        to_save = write_scene(
            scene_to_save,
            writer_args["writers"],
            writer_args,
            products_to_save,
            to_save=to_save,
        )

    if args.progress:
        pbar = ProgressBar()
        pbar.register()

    LOG.info("Computing products and saving data to writers...")
    if not to_save:
        LOG.warning(
            "No product files produced given available valid data and "
            "resampling settings. This can happen if the writer "
            "detects that no valid output will be written or the "
            "input data does not overlap with the target grid."
        )
    compute_writer_results(to_save)
    LOG.info("SUCCESS")
    return 0
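
The --progress branch above comes down to registering a ProgressBar globally before the final compute step, so that compute_writer_results shows a bar. A stripped-down, stand-alone sketch of that mechanic; the toy array and the show_progress flag are stand-ins rather than polar2grid code:

import dask.array as da
from dask.diagnostics import ProgressBar

show_progress = True              # stand-in for args.progress
if show_progress:
    pbar = ProgressBar()
    pbar.register()               # every compute() from here on prints a bar

x = da.random.random((4000, 4000), chunks=(1000, 1000))
result = (x @ x.T).sum().compute()   # analogous to the final compute_writer_results(to_save)
print(result)

if show_progress:
    pbar.unregister()             # optional tidy-up once the work is done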
Example #19
    def search(self, n_partitions=1, progress_bar='n'):
        r"""
        Top level search routine.

        Parameters
        ----------
        n_partitions : int
            Number of Dask partitions (processes) to use in parallel. Defaults to single-partition (process).
        progress_bar : str {'y', 'n'}, optional
            Enable command-line progress bar.

        Returns
        -------
        None.

        Notes
        -----
        self.data_handle.cchan_list : the list of coarse channel objects for searching,
             created by self.data_handle = DATAHandle() during __init__() execution.

        If using dask (n_partitions > 1):
        * Launch multiple drift searches in parallel.
        * Each search works on a single coarse channel object.
        * n_partitions governs the maximum number of partitions to run in parallel.
        Else, the searches are done in sequence of the coarse channel objects.

        It is not recommended to mix dask partitions with GPU mode as this could cause GPU queuing.
        """
        t0 = time.time()

        # Make libhdf5 errors visible.  I should not have to do this!
        unsilence_errors()  # from h5py._errors

        filename_in = self.data_handle.filename
        header_in = self.data_handle.header

        # As of 2.1.0, add max_drift_rate and obs_length to FileWriter header input
        header_in['max_drift_rate'] = self.max_drift

        wfilename = filename_in.split('/')[-1].replace('.h5', '').replace(
            '.fits', '').replace('.fil', '')
        path_log = '{}/{}.log'.format(self.out_dir.rstrip('/'), wfilename)
        path_dat = '{}/{}.dat'.format(self.out_dir.rstrip('/'), wfilename)
        if self.append_output:
            logger.debug('Appending DAT and LOG files')
        else:
            logger.debug('Recreating DAT and LOG files')
            if os.path.exists(path_log):
                os.remove(path_log)
            if os.path.exists(path_dat):
                os.remove(path_dat)
        logwriter = LogWriter(path_log)
        filewriter = FileWriter(path_dat, header_in)

        logwriter.info(version_announcements)

        msg = "HDF5 header info: {}\n".format(self.data_handle.get_info())
        logwriter.info(msg)
        print(msg)

        msg = 'Starting ET search with parameters: ' + self.parms + '\n'
        logwriter.info(msg)
        print(msg)

        msg = "Computed drift rate resolution: {}\n".format(
            self.data_handle.drift_rate_resolution)
        logwriter.info(msg)
        print(msg)

        # Run serial version
        if n_partitions == 1:
            sched = Scheduler(load_the_data,
                              [(cchan_obj, self.kernels.precision)
                               for cchan_obj in self.data_handle.cchan_list])
            for cchan_obj in self.data_handle.cchan_list:
                search_coarse_channel(cchan_obj,
                                      self,
                                      dataloader=sched,
                                      filewriter=filewriter,
                                      logwriter=logwriter)
        # Run Parallel version via dask
        else:
            print("FindDoppler.search: Using {} dask partitions".format(
                n_partitions))
            b = db.from_sequence(self.data_handle.cchan_list,
                                 npartitions=n_partitions)
            if progress_bar == 'y':
                with ProgressBar():
                    b.map(search_coarse_channel, self).compute()
            else:
                b.map(search_coarse_channel, self).compute()
            merge_dats_logs(filename_in, self.out_dir, 'dat', cleanup='y')
            merge_dats_logs(filename_in, self.out_dir, 'log', cleanup='y')

        t1 = time.time()
        self.last_logwriter(
            path_log, '\n===== Search time: {:.2f} minutes'.format(
                (t1 - t0) / 60.0))
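
The parallel branch above is the standard dask.bag recipe: map the per-channel work over a bag and compute it inside a ProgressBar context. A toy, self-contained version, where analyze is a made-up stand-in for search_coarse_channel:

import dask.bag as db
from dask.diagnostics import ProgressBar

def analyze(channel_id):
    # Stand-in for the per-coarse-channel search work.
    return sum(i * i for i in range(10_000 * (channel_id + 1)))

bag = db.from_sequence(range(8), npartitions=4)
with ProgressBar():
    results = bag.map(analyze).compute()
print(results)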
Example #20
    def normalize_intensity(
        self,
        num_std: int = 1,
        divide_by_square_root: bool = False,
        dtype_out: Optional[np.dtype] = None,
    ):
        """Normalize image intensities in inplace to a mean of zero with
        a given standard deviation.

        Parameters
        ----------
        num_std
            Number of standard deviations of the output intensities.
            Default is 1.
        divide_by_square_root
            Whether to divide output intensities by the square root of
            the signal dimension size. Default is False.
        dtype_out
            Data type of normalized images. If None (default), the input
            images' data type is used.

        Notes
        -----
        Data type should always be changed to floating point, e.g.
        ``np.float32`` with
        :meth:`~hyperspy.signal.BaseSignal.change_dtype`, before
        normalizing the intensities.

        Rescaling RGB images is not possible. Use RGB channel
        normalization when creating the image instead.

        Examples
        --------
        >>> import numpy as np
        >>> import kikuchipy as kp
        >>> s = kp.data.nickel_ebsd_small()
        >>> np.mean(s.data)
        146.0670987654321
        >>> s.normalize_intensity(dtype_out=np.float32)  # doctest: +SKIP
        >>> np.mean(s.data)  # doctest: +SKIP
        2.6373216e-08
        """
        if self.data.dtype in rgb_dtypes.values():
            raise NotImplementedError(
                "Use RGB channel normalization when creating the image instead."
            )

        if dtype_out is None:
            dtype_out = self.data.dtype

        dask_array = get_dask_array(self, dtype=np.float32)

        normalized_images = dask_array.map_blocks(
            func=chunk.normalize_intensity,
            num_std=num_std,
            divide_by_square_root=divide_by_square_root,
            dtype_out=dtype_out,
            dtype=dtype_out,
        )

        # Change data type if requested
        if dtype_out != self.data.dtype:
            self.change_dtype(dtype_out)

        # Overwrite signal patterns
        if not self._lazy:
            with ProgressBar():
                print("Normalizing the image intensities:", file=sys.stdout)
                normalized_images.store(self.data, compute=True)
        else:
            self.data = normalized_images
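
When the signal is not lazy, the method above uses map_blocks plus store(..., compute=True) to overwrite the backing array chunk by chunk while a ProgressBar reports progress. A self-contained sketch of that dask pattern with made-up image data (not kikuchipy's implementation):

import numpy as np
import dask.array as da
from dask.diagnostics import ProgressBar

# Fake 4x4 grid of 64x64 "patterns", held in a plain numpy array.
patterns = np.random.rand(4, 4, 64, 64).astype(np.float32)
lazy = da.from_array(patterns, chunks=(1, 1, 64, 64))

def normalize(block, num_std=1):
    # Zero mean, standard deviation of num_std per block.
    return num_std * (block - block.mean()) / block.std()

normalized = lazy.map_blocks(normalize, dtype=np.float32)

print("Normalizing the image intensities:")
with ProgressBar():
    normalized.store(patterns, compute=True)   # write results back in place
print(patterns.mean())                         # now close to zero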
Example #21
def return_gsea_capsules(ma=None,
                         tissue='',
                         context_on=False,
                         use_set=False,
                         gsea_superset='H',
                         n_top_sets=25,
                         min_capsule_len=2000,
                         all_genes=False,
                         union_cpgs=True,
                         limited_capsule_names_file=''):
    global gene2cpg, gsea_collections, gene_set_weights
    if limited_capsule_names_file:
        with open(limited_capsule_names_file) as f:
            limited_capsule_names = f.read().replace('\n', ' ').split()
    else:
        limited_capsule_names = []
    allcpgs = ma.beta.columns.values
    entire_sets = use_set
    collection = gsea_superset
    gene2cpg = pickle.load(open(gene2cpg, 'rb'))
    if all_genes:
        gene_sets = list(gene2cpg.keys())
    else:
        gsea = pickle.load(open(gsea_collections, 'rb'))
        if tissue:
            gene_sets = pd.read_csv(gene_set_weights[collection],
                                    sep='\t',
                                    index_col=0)
            if tissue != 'ubiquitous':
                gene_sets = (gene_sets.quantile(1., axis=1) -
                             gene_sets.quantile(
                                 0., axis=1)).sort_values().index.tolist()
            else:
                gene_sets = gene_sets[tissue].sort_values(
                    ascending=False).index.tolist()
        else:
            gene_sets = list(gsea[collection].keys())
    intersect_context = False
    if limited_capsule_names_file:
        gene_sets_tmp = np.intersect1d(gene_sets,
                                       limited_capsule_names).tolist()
        print('LIMITED GENE CAPS', gene_sets_tmp)
        if gene_sets_tmp:
            gene_sets = gene_sets_tmp
            intersect_context = True
    if not tissue:
        n_top_sets = 0
    if n_top_sets and not all_genes:
        gene_sets = gene_sets[:n_top_sets]

    capsules = dict()
    if all_genes:
        entire_sets = False
    if entire_sets:
        context_on = False

    def process_gene_set(gene_set):
        capsules = []
        gene_set_cpgs = []
        for genename in (gsea[collection][gene_set]
                         if not all_genes else [gene_set]):
            gene = gene2cpg.get(genename, {'Gene': [], 'Upstream': []})
            if context_on:
                for k in ['Gene', 'Upstream']:
                    context = gene.get(k, [])
                    if len(context):
                        capsules.append(('{}_{}'.format(genename,
                                                        k), list(context)))
                        #capsules['{}_{}'.format(genename,k)]=context.tolist()
            else:
                if not entire_sets:
                    capsules.append((genename,
                                     np.union1d(gene.get('Gene', []),
                                                gene.get('Upstream',
                                                         [])).tolist()))
                    #capsules[genename]=np.union1d(gene.get('Gene',[]),gene.get('Upstream',[])).tolist()
                else:
                    upstream = gene.get('Upstream', [])
                    gene = gene.get('Gene', [])
                    cpg_set = np.union1d(gene, upstream)
                    if cpg_set.tolist():
                        gene_set_cpgs.append(cpg_set)
        if entire_sets and not all_genes:
            capsules.append((gene_set, reduce(np.union1d,
                                              gene_set_cpgs).tolist()))
            #capsules[gene_set]=reduce(np.union1d,gene_set_cpgs).tolist()
        return capsules

    def process_chunk(chunk):
        with ProgressBar():
            chunk = dask.compute(*chunk, scheduler='threading')
        return chunk

    with ProgressBar():
        capsules = dict(
            list(
                reduce(
                    lambda x, y: x + y,
                    dask.compute(*[
                        dask.delayed(process_gene_set)(gene_set)
                        for gene_set in gene_sets
                    ],
                                 scheduler='threading'))))

    capsules2 = []
    #caps_lens=np.array([len(capsules[capsule]) for capsule in capsules])

    # cluster = LocalCluster(n_workers=multiprocessing.cpu_count()*2, threads_per_worker=20)
    # client = Client(cluster)
    capsule_names = list(capsules.keys())

    if intersect_context:
        capsules_tmp_names = np.intersect1d(capsule_names,
                                            limited_capsule_names).tolist()
        if capsules_tmp_names:
            capsules = {k: capsules[k] for k in capsules_tmp_names}
            capsule_names = capsules_tmp_names

    capsules = reduce_caps(capsules, allcpgs, min_capsule_len)

    # print(capsule_names)
    # capsules_bag=db.from_sequence(list(capsules.values()))
    # capsules_intersect=capsules_bag.map(lambda x: np.intersect1d(x,allcpgs))
    # capsules_len=capsules_intersect.map(lambda x: x if len(x) >= min_capsule_len else [])
    # # with get_task_stream(plot='save', filename="task-stream.html") as ts:
    # capsules=capsules_len.compute()
    # #print(capsules)
    # capsules=dict([(capsule_names[i],capsules[i].tolist()) for i in range(len(capsule_names)) if len(capsules[i])])

    # for capsule in capsules:
    # 	capsules2.append([capsule,dask.delayed(return_caps)(capsules[capsule],allcpgs,min_capsule_len)])
    # cpus=multiprocessing.cpu_count()
    # caps_chunks=list(divide_chunks(capsules2,cpus))
    # p=Pool(cpus)
    # capsules=dict(list(reduce(lambda x,y: x+y,p.map(process_chunk,caps_chunks))))

    # with ProgressBar():
    # 	capsules=dask.compute(capsules2,scheduler='threading')[0]
    #print(capsules)
    modules = list(capsules.values())  # [capsules[capsule] for capsule in capsules if capsules[capsule]]
    modulecpgs = reduce((np.union1d if union_cpgs else (lambda x, y: x + y)),
                        modules).tolist()
    module_names = list(capsules.keys())

    return modules, modulecpgs, module_names
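
The capsule construction above is a dask.delayed fan-out evaluated with the threading scheduler inside a ProgressBar context. A minimal stand-alone version of just that mechanic; the gene-set names and the per-set work are invented:

import dask
from dask.diagnostics import ProgressBar

def cpgs_for_gene_set(name):
    # Stand-in for process_gene_set: return (capsule_name, cpg_list) pairs.
    return [(f"{name}_{i}", list(range(i + 1))) for i in range(3)]

tasks = [dask.delayed(cpgs_for_gene_set)(g)
         for g in ["HALLMARK_A", "HALLMARK_B", "HALLMARK_C"]]

with ProgressBar():
    pair_lists = dask.compute(*tasks, scheduler="threading")

capsules = dict(sum(pair_lists, []))   # flatten the per-set lists into one dict
print(sorted(capsules))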
Example #22
def cluster_build_trees(identity,
                        set_name,
                        cluster_file=None,
                        click_loguru=None):
    """Calculate homology clusters, MSAs, trees."""
    options = click_loguru.get_global_options()
    user_options = click_loguru.get_user_global_options()
    parallel = user_options["parallel"]
    set_path = Path(set_name)
    # read and possibly update proteomes
    proteomes_path = set_path / PROTEOMES_FILE
    proteomes_in = read_tsv_or_parquet(proteomes_path)
    proteomes = sort_proteome_frame(proteomes_in)
    if not proteomes_in.equals(proteomes):
        logger.info("proteomes sort order changed, writing new proteomes file")
        write_tsv_or_parquet(proteomes, proteomes_path)
    n_proteomes = len(proteomes)
    # read and update fragment ID's
    frags = read_tsv_or_parquet(set_path / FRAGMENTS_FILE)
    frags["frag.idx"] = pd.array(frags.index, dtype=pd.UInt32Dtype())
    frag_frames = {}
    for dotpath, subframe in frags.groupby(by=["path"]):
        frag_frames[dotpath] = subframe.copy().set_index("frag.orig_id")
    arg_list = []
    concat_fasta_path = set_path / "proteins.fa"
    for i, row in proteomes.iterrows():
        arg_list.append((row, concat_fasta_path, frag_frames[row["path"]]))
    file_idx = {}
    stem_dict = {}
    for i, row in proteomes.iterrows():
        stem = row["path"]
        file_idx[stem] = i
        stem_dict[i] = stem
    if cluster_file is None:
        if concat_fasta_path.exists():
            concat_fasta_path.unlink()
        if not options.quiet:
            logger.info(
                f"Renaming fragments and concatenating sequences for {len(arg_list)}"
                " proteomes:")
        for args in arg_list:
            write_protein_fasta(args)
        del arg_list
        cwd = Path.cwd()
        os.chdir(set_path)
        n_clusters, run_stats, cluster_hist = homology_cluster(
            "proteins.fa",
            identity,
            write_ids=True,
            delete=False,
            cluster_stats=False,
            outname="homology",
            click_loguru=click_loguru,
        )
        log_path = Path("homology.log")
        log_dir_path = Path("logs")
        log_dir_path.mkdir(exist_ok=True)
        shutil.copy2(log_path, "logs/homology.log")
        log_path.unlink()
        os.chdir(cwd)
        logger.info(f"Number of clusters: {n_clusters}")
        del cluster_hist
        del run_stats
        concat_fasta_path.unlink()
    else:  # use pre-existing clusters
        homology_path = set_path / "homology"
        if homology_path.exists():
            shutil.rmtree(homology_path)
        inclusts = pd.read_csv(cluster_file, sep="\t")
        for col in ["cluster_id", "members"]:
            if col not in inclusts.columns:
                logger.error(
                    f'Column named "{col}" not found in external homology cluster file'
                )
                sys.exit(1)
        cluster_counts = inclusts["cluster_id"].value_counts()
        cluster_map = pd.Series(range(len(cluster_counts)),
                                index=cluster_counts.index)
        cluster_ids = inclusts["cluster_id"].map(cluster_map)
        cluster_sizes = inclusts["cluster_id"].map(cluster_counts)
        predef_clusters = pd.DataFrame({
            "cluster_id": cluster_ids,
            "size": cluster_sizes,
            "members": inclusts["members"],
        })
        predef_clusters.sort_values(by=["cluster_id"], inplace=True)
        predef_clusters.drop(
            predef_clusters[predef_clusters["size"] < 2].index,
            axis=0,
            inplace=True,
        )
        n_clusters = predef_clusters["cluster_id"].max() + 1
        predef_clusters.index = range(len(predef_clusters))
        external_cluster_path = set_path / EXTERNAL_CLUSTERS_FILE
        logger.info(
            f"Writing {external_cluster_path} with {len(predef_clusters)} genes"
            + f" in {n_clusters} homology clusters")
        predef_clusters.to_csv(external_cluster_path, sep="\t")
        del cluster_counts, cluster_map, cluster_sizes, inclusts
        homology_path = set_path / "homology"
        homology_path.mkdir(exist_ok=True)
        if not options.quiet:
            logger.info(f"Creating cluster files for for {len(arg_list)}"
                        " proteomes:")
        proteome_no = 0
        for args in arg_list:
            logger.info(f"doing proteome {proteome_no}")
            write_protein_fasta(args,
                                fasta_dir=homology_path,
                                clusters=predef_clusters)
            proteome_no += 1
        del arg_list
        logger.info(
            "Checking that all cluster files are present (gene-id mismatch)")
        missing_files = False
        for i in range(n_clusters):
            if not (homology_path / f"{i}.fa").exists():
                logger.error(f"External cluster {i} is missing.")
                missing_files = True
        if missing_files:
            sys.exit(1)
    #
    # Write homology info back into proteomes
    #
    click_loguru.elapsed_time("Alignment/tree-building")
    hom_mb = DataMailboxes(
        n_boxes=n_proteomes,
        mb_dir_path=(set_path / "mailboxes" / "clusters2proteomes"),
        file_extension="tsv",
    )
    hom_mb.write_tsv_headers(HOMOLOGY_COLS)
    cluster_paths = [
        set_path / "homology" / f"{i}.fa" for i in range(n_clusters)
    ]
    bag = db.from_sequence(cluster_paths)
    cluster_stats = []
    if not options.quiet:
        logger.info(
            f"Calculating MSAs and trees for {len(cluster_paths)} homology"
            " clusters:")
        ProgressBar(dt=SPINNER_UPDATE_PERIOD).register()
    if parallel:
        cluster_stats = bag.map(
            parse_cluster,
            file_dict=file_idx,
            file_writer=hom_mb.locked_open_for_write,
        )
    else:
        for clust_fasta in cluster_paths:
            cluster_stats.append(
                parse_cluster(
                    clust_fasta,
                    file_dict=file_idx,
                    file_writer=hom_mb.locked_open_for_write,
                ))
    n_clust_genes = 0
    clusters_dict = {}
    for cluster_id, cluster_dict in cluster_stats:
        n_clust_genes += cluster_dict["size"]
        clusters_dict[cluster_id] = cluster_dict
    del cluster_stats
    clusters = pd.DataFrame.from_dict(clusters_dict).transpose()
    del clusters_dict
    clusters.sort_index(inplace=True)
    grouping_dict = {}
    for i in range(n_proteomes):  # keep numbering of single-file clusters
        grouping_dict[f"[{i}]"] = i
    grouping_dict[str(list(range(n_proteomes)))] = 0
    for n_members, subframe in clusters.groupby(["n_memb"]):
        if n_members == 1:
            continue
        if n_members == n_proteomes:
            continue
        member_counts = pd.DataFrame(subframe["n_members"].value_counts())
        member_counts["key"] = range(len(member_counts))
        for newcol in range(n_members):
            member_counts[f"memb{newcol}"] = ""
        for member_string, row in member_counts.iterrows():
            grouping_dict[member_string] = row["key"]
            member_list = json.loads(member_string)
            for col in range(n_members):
                member_counts.loc[member_string,
                                  f"memb{col}"] = stem_dict[member_list[col]]
        member_counts = member_counts.set_index("key")
        write_tsv_or_parquet(member_counts,
                             set_path / group_key_filename(n_members))
    clusters["n_members"] = clusters["n_members"].map(grouping_dict)
    clusters = clusters.rename(columns={"n_members": "group_key"})
    n_adj = clusters["n_adj"].sum()
    adj_pct = n_adj * 100.0 / n_clust_genes
    n_adj_clust = sum(clusters["adj_groups"] != 0)
    adj_clust_pct = n_adj_clust * 100.0 / len(clusters)
    logger.info(f"{n_adj} ({adj_pct:.1f}%) out of {n_clust_genes}" +
                " clustered genes are adjacent")
    logger.info(f"{n_adj_clust} ({adj_clust_pct:.1f}%) out of " +
                f"{len(clusters)} clusters contain adjacency")
    write_tsv_or_parquet(clusters, set_path / CLUSTERS_FILE)
    # join homology cluster info to proteome info
    click_loguru.elapsed_time("Joining")
    arg_list = []
    for i, row in proteomes.iterrows():
        arg_list.append((
            i,
            dotpath_to_path(row["path"]),
        ))
    bag = db.from_sequence(arg_list)
    hom_stats = []
    if not options.quiet:
        logger.info(f"Joining homology info to {n_proteomes} proteomes:")
        ProgressBar(dt=SPINNER_UPDATE_PERIOD).register()
    if parallel:
        hom_stats = bag.map(join_homology_to_proteome,
                            mailbox_reader=hom_mb.open_then_delete).compute()
    else:
        for args in arg_list:
            hom_stats.append(
                join_homology_to_proteome(
                    args, mailbox_reader=hom_mb.open_then_delete))
    hom_mb.delete()
    hom_frame = pd.DataFrame.from_dict(hom_stats)
    hom_frame.set_index(["prot.idx"], inplace=True)
    hom_frame.sort_index(inplace=True)
    logger.info("Homology cluster coverage:")
    with pd.option_context("display.max_rows", None, "display.float_format",
                           "{:,.2f}%".format):
        logger.info(hom_frame)
    proteomes = pd.concat([proteomes, hom_frame], axis=1)
    write_tsv_or_parquet(proteomes,
                         set_path / PROTEOMOLOGY_FILE,
                         float_format="%5.2f")
    click_loguru.elapsed_time(None)
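
The two progress bars above are registered globally with a custom refresh interval (ProgressBar(dt=SPINNER_UPDATE_PERIOD).register()) instead of being used as context managers. A self-contained sketch of that variant; the update period, file names, and per-cluster work are placeholders:

import dask.bag as db
from dask.diagnostics import ProgressBar

SPINNER_UPDATE_PERIOD = 1.0                 # seconds between bar redraws (assumed value)
ProgressBar(dt=SPINNER_UPDATE_PERIOD).register()

def parse_cluster_stub(path):
    # Placeholder for parse_cluster: pretend to compute per-cluster statistics.
    return {"path": path, "size": len(path)}

cluster_paths = [f"homology/{i}.fa" for i in range(50)]
stats = db.from_sequence(cluster_paths).map(parse_cluster_stub).compute()
print(len(stats), "clusters processed")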
Example #23
import dask
from dask.diagnostics import ProgressBar

def process_chunk(chunk):
    with ProgressBar():
        chunk = dask.compute(*chunk, scheduler='threading')
    return chunk
Example #24
import inspect
import os
# this is a class to deal with aqs data
from builtins import object, zip

import pandas as pd
from dask.diagnostics import ProgressBar

from .epa_util import read_monitor_file

pbar = ProgressBar()
pbar.register()


def add_data(dates,
             param=None,
             daily=False,
             network=None,
             download=False,
             local=False,
             wide_fmt=True,
             n_procs=1,
             meta=False):
    from ..util import long_to_wide
    a = AQS()
    df = a.add_data(dates,
                    param=param,
                    daily=daily,
                    network=network,
                    download=download,
                    local=local,
                    # NOTE: the original snippet is truncated here; the
                    # remaining keyword arguments are an assumed completion.
                    wide_fmt=wide_fmt,
                    n_procs=n_procs,
                    meta=meta)
    return df
Example #25
def main(argv=sys.argv[1:]):
    global LOG
    from satpy import Scene
    from satpy.resample import get_area_def
    from satpy.writers import compute_writer_results
    from dask.diagnostics import ProgressBar
    from polar2grid.core.script_utils import (setup_logging, rename_log_file,
                                              create_exc_handler)
    import argparse
    prog = os.getenv('PROG_NAME', sys.argv[0])
    # "usage: " will be printed at the top of this:
    usage = """
    %(prog)s -h
see available products:
    %(prog)s -r <reader> -w <writer> --list-products -f file1 [file2 ...]
basic processing:
    %(prog)s -r <reader> -w <writer> [options] -f file1 [file2 ...]
basic processing with limited products:
    %(prog)s -r <reader> -w <writer> [options] -p prod1 prod2 -f file1 [file2 ...]
"""
    parser = argparse.ArgumentParser(
        prog=prog,
        usage=usage,
        description="Load, composite, resample, and save datasets.")
    parser.add_argument(
        '-v',
        '--verbose',
        dest='verbosity',
        action="count",
        default=0,
        help=
        'each occurrence increases verbosity 1 level through ERROR-WARNING-INFO-DEBUG (default INFO)'
    )
    parser.add_argument('-l',
                        '--log',
                        dest="log_fn",
                        default=None,
                        help="specify the log filename")
    parser.add_argument(
        '--progress',
        action='store_true',
        help="show processing progress bar (not recommended for logged output)"
    )
    parser.add_argument(
        '--num-workers',
        type=int,
        default=4,
        help="specify number of worker threads to use (default: 4)")
    parser.add_argument(
        '--match-resolution',
        dest='preserve_resolution',
        action='store_false',
        help="When using the 'native' resampler for composites, don't save data "
        "at its native resolution, use the resolution used to create the "
        "composite.")
    parser.add_argument('-w',
                        '--writers',
                        nargs='+',
                        help='writers to save datasets with')
    parser.add_argument("--list-products",
                        dest="list_products",
                        action="store_true",
                        help="List available reader products and exit")
    subgroups = add_scene_argument_groups(parser)
    subgroups += add_resample_argument_groups(parser)

    argv_without_help = [x for x in argv if x not in ["-h", "--help"]]
    args, remaining_args = parser.parse_known_args(argv_without_help)

    # get the logger if we know the readers and writers that will be used
    if args.reader is not None and args.writers is not None:
        glue_name = args.reader + "_" + "-".join(args.writers or [])
        LOG = logging.getLogger(glue_name)
    # add writer arguments
    if args.writers is not None:
        for writer in (args.writers or []):
            parser_func = WRITER_PARSER_FUNCTIONS.get(writer)
            if parser_func is None:
                continue
            subgroups += parser_func(parser)
    args = parser.parse_args(argv)

    if args.reader is None:
        parser.print_usage()
        parser.exit(
            1, "\nERROR: Reader must be provided (-r flag).\n"
            "Supported readers:\n\t{}\n".format('\n\t'.join(
                ['abi_l1b', 'ahi_hsd', 'hrit_ahi'])))
    if args.writers is None:
        parser.print_usage()
        parser.exit(
            1,
            "\nERROR: Writer must be provided (-w flag) with one or more writer.\n"
            "Supported writers:\n\t{}\n".format('\n\t'.join(['geotiff'])))

    def _args_to_dict(group_actions):
        return {
            ga.dest: getattr(args, ga.dest)
            for ga in group_actions if hasattr(args, ga.dest)
        }

    scene_args = _args_to_dict(subgroups[0]._group_actions)
    load_args = _args_to_dict(subgroups[1]._group_actions)
    resample_args = _args_to_dict(subgroups[2]._group_actions)
    writer_args = {}
    for idx, writer in enumerate(args.writers):
        sgrp1, sgrp2 = subgroups[3 + idx * 2:5 + idx * 2]
        wargs = _args_to_dict(sgrp1._group_actions)
        if sgrp2 is not None:
            wargs.update(_args_to_dict(sgrp2._group_actions))
        writer_args[writer] = wargs
        # get default output filename
        if 'filename' in wargs and wargs['filename'] is None:
            wargs['filename'] = get_default_output_filename(
                args.reader, writer)

    if not args.filenames:
        parser.print_usage()
        parser.exit(1, "\nERROR: No data files provided (-f flag)\n")

    # Prepare logging
    rename_log = False
    if args.log_fn is None:
        rename_log = True
        args.log_fn = glue_name + "_fail.log"
    levels = [logging.ERROR, logging.WARN, logging.INFO, logging.DEBUG]
    setup_logging(console_level=levels[min(3, args.verbosity)],
                  log_filename=args.log_fn)
    logging.getLogger('rasterio').setLevel(levels[min(2, args.verbosity)])
    sys.excepthook = create_exc_handler(LOG.name)
    if levels[min(3, args.verbosity)] > logging.DEBUG:
        import warnings
        warnings.filterwarnings("ignore")
    LOG.debug("Starting script with arguments: %s", " ".join(sys.argv))

    # Set up dask and the number of workers
    if args.num_workers:
        from multiprocessing.pool import ThreadPool
        dask.config.set(pool=ThreadPool(args.num_workers))

    # Parse provided files and search for files if provided directories
    scene_args['filenames'] = get_input_files(scene_args['filenames'])
    # Create a Scene, analyze the provided files
    LOG.info("Sorting and reading input files...")
    try:
        scn = Scene(**scene_args)
    except ValueError as e:
        LOG.error(
            "{} | Enable debug message (-vvv) or see log file for details.".
            format(str(e)))
        LOG.debug("Further error information: ", exc_info=True)
        return -1
    except OSError:
        LOG.error(
            "Could not open files. Enable debug message (-vvv) or see log file for details."
        )
        LOG.debug("Further error information: ", exc_info=True)
        return -1

    if args.list_products:
        print("\n".join(sorted(scn.available_dataset_names(composites=True))))
        return 0

    # Rename the log file
    if rename_log:
        rename_log_file(glue_name +
                        scn.attrs['start_time'].strftime("_%Y%m%d_%H%M%S.log"))

    # Load the actual data arrays and metadata (lazy loaded as dask arrays)
    if load_args['products'] is None:
        try:
            reader_mod = importlib.import_module('polar2grid.readers.' +
                                                 scene_args['reader'])
            load_args['products'] = reader_mod.DEFAULT_PRODUCTS
            LOG.info("Using default product list: {}".format(
                load_args['products']))
        except (ImportError, AttributeError):
            LOG.error(
                "No default products list set, please specify with `--products`."
            )
            return -1

    LOG.info("Loading product metadata from files...")
    scn.load(load_args['products'])

    resample_kwargs = resample_args.copy()
    areas_to_resample = resample_kwargs.pop('grids')
    grid_configs = resample_kwargs.pop('grid_configs')
    resampler = resample_kwargs.pop('resampler')

    if areas_to_resample is None and resampler in [None, 'native']:
        # no areas specified
        areas_to_resample = ['MAX']
    elif areas_to_resample is None:
        raise ValueError(
            "Resampling method specified (--method) without any destination grid/area (-g flag)."
        )
    elif not areas_to_resample:
        # they don't want any resampling (they used '-g' with no args)
        areas_to_resample = [None]

    p2g_grid_configs = [x for x in grid_configs if x.endswith('.conf')]
    pyresample_area_configs = [
        x for x in grid_configs if not x.endswith('.conf')
    ]
    if not grid_configs or p2g_grid_configs:
        # if we were given p2g grid configs or we weren't given any to choose from
        from polar2grid.grids import GridManager
        grid_manager = GridManager(*p2g_grid_configs)
    else:
        grid_manager = {}

    if pyresample_area_configs:
        from pyresample.utils import parse_area_file
        custom_areas = parse_area_file(pyresample_area_configs)
        custom_areas = {x.area_id: x for x in custom_areas}
    else:
        custom_areas = {}

    ll_bbox = resample_kwargs.pop('ll_bbox')
    if ll_bbox:
        scn = scn.crop(ll_bbox=ll_bbox)

    wishlist = scn.wishlist.copy()
    preserve_resolution = get_preserve_resolution(args, resampler,
                                                  areas_to_resample)
    if preserve_resolution:
        preserved_products = set(wishlist) & set(scn.datasets.keys())
        resampled_products = set(wishlist) - preserved_products

        # original native scene
        to_save = write_scene(scn, args.writers, writer_args,
                              preserved_products)
    else:
        preserved_products = set()
        resampled_products = set(wishlist)
        to_save = []

    LOG.debug(
        "Products to preserve resolution for: {}".format(preserved_products))
    LOG.debug(
        "Products to use new resolution for: {}".format(resampled_products))
    for area_name in areas_to_resample:
        if area_name is None:
            # no resampling
            area_def = None
        elif area_name == 'MAX':
            area_def = scn.max_area()
        elif area_name == 'MIN':
            area_def = scn.min_area()
        elif area_name in custom_areas:
            area_def = custom_areas[area_name]
        elif area_name in grid_manager:
            p2g_def = grid_manager[area_name]
            area_def = p2g_def.to_satpy_area()
            if isinstance(area_def, DynamicAreaDefinition
                          ) and p2g_def['cell_width'] is not None:
                area_def = area_def.freeze(
                    scn.max_area(),
                    resolution=(abs(p2g_def['cell_width']),
                                abs(p2g_def['cell_height'])))
        else:
            area_def = get_area_def(area_name)

        if resampler is None and area_def is not None:
            rs = 'native' if area_name in ['MIN', 'MAX'] or is_native_grid(
                area_def, scn.max_area()) else 'nearest'
            LOG.debug(
                "Setting default resampling to '{}' for grid '{}'".format(
                    rs, area_name))
        else:
            rs = resampler

        if area_def is not None:
            LOG.info("Resampling data to '%s'", area_name)
            new_scn = scn.resample(area_def, resampler=rs, **resample_kwargs)
        elif not preserve_resolution:
            # the user didn't want to resample to any areas
            # the user also requested that we don't preserve resolution
            # which means we have to save this Scene's datasets
            # because they won't be saved
            new_scn = scn

        to_save = write_scene(new_scn,
                              args.writers,
                              writer_args,
                              resampled_products,
                              to_save=to_save)

    if args.progress:
        pbar = ProgressBar()
        pbar.register()

    LOG.info("Computing products and saving data to writers...")
    compute_writer_results(to_save)
    LOG.info("SUCCESS")
    return 0
Example #26
def cli():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Packs the raw PNG images into TFRecords.")
    parser.add_argument("--raw-images",
                        type=str,
                        help="Path of the raw images",
                        default=rio.DEFAULT_IMAGES_BASE_PATH)
    parser.add_argument("--metadata",
                        type=str,
                        help="Path to the metadata directory",
                        default=rio.DEFAULT_METADATA_BASE_PATH)
    parser.add_argument(
        "--num-workers",
        type=int,
        default=None,
        help=
        "Number of workers to be writing TFRecords. Defaults to number of cores."
    )
    parser.add_argument(
        "--random-seeds",
        type=int,
        nargs='+',
        default=[42],
        help=
        "The seed used to make the sorting determistic. Embedded in the dir name to allow multiple folds to be created."
    )
    parser.add_argument(
        "--sites-per-tfrecord",
        type=int,
        default=1500,
        help=
        "Only used with the random strategy, indicates how many site images you want in a single TFRecord"
    )
    parser.add_argument("--strategies",
                        nargs='+',
                        choices=VALID_STRATEGIES,
                        default=['random', 'by_exp_plate_site'],
                        help="""What strategies to use to pack up the records:
\t`random` - Randomly partitions each dataset into multiple TFRecords.
\t`by_exp_plate_site` - Groups by experiment, plate, and packs each site into individual TFRecords.
                        """)
    parser.add_argument(
        "--dest-path",
        type=str,
        default="./tfrecords",
        help="Destination directory of where to write the tfrecords")
    parser.add_argument("--runner",
                        type=str,
                        default="dask",
                        choices={'dask', 'dataflow'},
                        help="Specify one of DirectRunner, dataflow, or dask")
    parser.add_argument("--project",
                        type=str,
                        default=None,
                        help="If using dataflow, the project to bill")
    args = parser.parse_args()
    if args.runner == 'dataflow':
        if not args.project:
            raise ValueError(
                'When using dataflow, you need to specify project')

    metadata_df = rio.combine_metadata(args.metadata)
    if args.runner == 'dask':
        from dask.diagnostics import ProgressBar
        ProgressBar().register()

    pack_tfrecords(images_path=args.raw_images,
                   metadata_df=metadata_df,
                   dest_path=args.dest_path,
                   strategies=args.strategies,
                   sites_per_tfrecord=args.sites_per_tfrecord,
                   random_seeds=args.random_seeds,
                   num_workers=args.num_workers,
                   runner=args.runner,
                   project=args.project)
Example #27
def nnet_tuning(n_layers=[1, 2, 1],
                layer_size=[20, 21, 1],
                k=5,
                train_data_path='../data/training_data.csv',
                save_model=False,
                tracking_uri="http://0.0.0.0:5000"):

    # Log the parameters with mlflow
    mlflow.log_param("n_layers", n_layers)
    mlflow.set_tag("layer_size", layer_size)

    # Set random seed for reproducibility
    np.random.seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)

    # Get data shuffled and split into training and test sets
    mdr = MiningDataReader(path=train_data_path)
    (variable_names, X_train, X_test, y_train,
     y_test) = mdr.get_splitted_data()

    pipeline = Pipeline(
        steps=[('scaling', StandardScaler()
                ), ('regression', MLPRegressor(random_state=RANDOM_SEED))])

    ### TRAINING ###
    ################

    # Generate all combinations for number of layers and layer size
    neurons_per_layer = tuple(
        np.arange(layer_size[0], layer_size[1], layer_size[2]))
    hls_values = []
    for layers_num in np.arange(n_layers[0], n_layers[1], n_layers[2]):
        hls_values.append([
            x for x in itertools.product(neurons_per_layer, repeat=layers_num)
        ])

    # Flatten the list
    hls_values = [item for sublist in hls_values for item in sublist]

    # Generate grid search for hyperparam tuning
    hyperparams = {}
    hyperparams['regression__hidden_layer_sizes'] = hls_values

    print("Training started...\n")

    # Grid-search the MLP pipeline and fit the data for all hyperparameter combinations using all processors
    modelCV = GridSearchCV(estimator=pipeline,
                           param_grid=hyperparams,
                           cv=k,
                           scoring='neg_mean_squared_error',
                           n_jobs=-1)

    with ProgressBar():
        modelCV.fit(X_train, y_train)

    # Iterate over the results storing training error for each hyperparameter combination
    results = modelCV.cv_results_
    param_list, training_err_list, training_dev_list = [], [], []
    for i in range(len(results['params'])):
        param = results['params'][i]
        score = (-1) * results['mean_test_score'][i]  # NEGATIVE MSE
        std = results['std_test_score'][i]
        param_list.append(param)
        training_err_list.append(score)
        training_dev_list.append(std)

    print(
        f"\nBest parameter set found for the training set:\n{modelCV.best_params_}"
    )

    # Store the index of the best combination
    best_index = param_list.index(modelCV.best_params_)

    # Get the best values for hyperparams
    best_hls = modelCV.best_params_['regression__hidden_layer_sizes']

    print("\nTraining finished. Evaluating model...\n")

    ### EVALUATION ###
    ##################

    # Criteria is hidden_layer_sizes
    criteria = 'hidden_layer_sizes'
    mlflow.set_tag("criteria", criteria)
    param_values = hls_values

    # Predict test data varying the criteria param and evaluate the models
    training_err_by_criteria, training_dev_by_criteria, test_err_list = [], [], []
    rmse_score, mae_score, r2_score = -1, -1, -1
    feature_names, feature_importances = [], []
    for param_value in tqdm(param_values):
        model = Pipeline(steps=[('scaler', StandardScaler()),
                                ('regression',
                                 MLPRegressor(hidden_layer_sizes=param_value,
                                              random_state=RANDOM_SEED))])
        param = {'regression__hidden_layer_sizes': param_value}

        # Fit model and evaluate results
        model.fit(X_train, y_train)
        prediction = model.predict(X_test)
        index = param_list.index(param)
        training_err = training_err_list[index]
        training_dev = training_dev_list[index]
        (training_mse, test_mse, rmse, mae,
         r2) = get_test_metrics(training_err, y_test, prediction)
        # Store metrics
        training_err_by_criteria.append(training_mse)
        training_dev_by_criteria.append(training_dev)
        test_err_list.append(test_mse)
        # Set additional metrics for the best combination
        if index == best_index:
            rmse_score = rmse
            mae_score = mae
            r2_score = r2

    # Generate the plots
    empty_img_folder()
    plot_errors(criteria, param_values, training_err_by_criteria,
                training_dev_by_criteria, test_err_list)

    # Once hyperparameters are selected, train and save the best model
    if save_model:
        print(
            "\nEvaluation finished. Training final model with train + test data with the best hyperparameters..."
        )
        final_model = Pipeline(
            steps=[('scaler', StandardScaler()),
                   ('regression',
                    MLPRegressor(hidden_layer_sizes=param_list[best_index]
                                 ['regression__hidden_layer_sizes']))])

        # Train the best model with all the data (training + test)
        full_X = np.vstack((X_train, X_test))
        full_y = np.concatenate((y_train, y_test))
        final_model.fit(full_X, full_y)

        # Log plots and model with mlflow
        mlflow.log_artifacts('./img')
        mlflow.sklearn.log_model(final_model, 'model')

    # Log results with mlflow
    mlflow.log_metric("train_mse", training_err_list[best_index])
    mlflow.log_metric("test_mse", min(test_err_list))
    mlflow.log_metric("rmse", rmse_score)
    mlflow.log_metric("mae", mae_score)
    mlflow.log_metric("r2", r2_score)
    mlflow.set_tag("best_params", param_list[best_index])

    # Output the results
    print(f'''
-----------------------------------------------------------------------------------------------------------------------
RESULTS
-----------------------------------------------------------------------------------------------------------------------
Best params: {param_list[best_index]}
Training MSE: {training_err_list[best_index]}
Test MSE: {min(test_err_list)}
RMSE: {rmse_score}
MAE: {mae_score}
R2: {r2_score}
-----------------------------------------------------------------------------------------------------------------------
''')
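
One caveat: dask's ProgressBar is a callback for dask's own schedulers, so wrapping a plain scikit-learn GridSearchCV.fit in with ProgressBar(): (as above) does not by itself produce a bar, since joblib performs the work. A hedged sketch of one way to get a bar for a similar sweep, by expressing each cross-validation as a dask.delayed task; toy data, and not the script's actual approach:

import dask
import numpy as np
from dask.diagnostics import ProgressBar
from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPRegressor

X = np.random.rand(200, 5)
y = np.random.rand(200)

hls_values = [(10,), (20,), (10, 10)]
tasks = [
    dask.delayed(cross_val_score)(
        MLPRegressor(hidden_layer_sizes=hls, max_iter=300), X, y,
        cv=3, scoring="neg_mean_squared_error")
    for hls in hls_values
]

with ProgressBar():                       # each finished task advances the bar
    scores = dask.compute(*tasks, scheduler="threads")

best = max(zip(hls_values, scores), key=lambda kv: kv[1].mean())
print("best hidden_layer_sizes:", best[0])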
Example #28
# Absolute paths.
FILES = sorted(glob.glob(input_folder + "*." + extension))

# Read all data.
if extension == "grb":  # Grib files.
    DS = xr.open_mfdataset(FILES, engine="cfgrib")

else:  # Netcdf files
    DS = xr.open_mfdataset(FILES)

# Make sure data is time ordered.
DS = DS.sortby("time")

# Load data.
print("\n>>> Loading data into memory ...")
with ProgressBar():
    DS = DS.sortby("time")

# Extract data array.
DA = getattr(DS, code)

# Calculate linear trends and its parameters.
trends, parameters = alextools.linear_trends(DA)

# Seasonality of detrended data.
seasonal = alextools.climatology(DA)

# Anomalies.
anomalies = alextools.anomalies(DA)

# Put all results in the same xarray Dataset object.
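
Because the fragment above depends on files on disk, here is a self-contained stand-in for the same load-with-progress step, using an in-memory dataset chunked with dask; the variable name t2m is made up:

import numpy as np
import pandas as pd
import xarray as xr
from dask.diagnostics import ProgressBar

time = pd.date_range("2000-01-01", periods=365, freq="D")
ds = xr.Dataset({"t2m": ("time", np.random.rand(365))}, coords={"time": time}).chunk({"time": 30})

ds = ds.sortby("time")              # cheap, stays lazy

print("\n>>> Loading data into memory ...")
with ProgressBar():
    ds = ds.load()                  # materialize the dask-backed variables

print(float(ds.t2m.mean()))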
Example #29
def with_and_without_transit(n, snapshots=None, branch_components=None):
    """
    Compute the with-and-without flows and losses.

    This function only works with the linear power flow so far and calculates
    the loss which *would* take place according to

    f²⋅r

    which is the loss for directed currents. If links are included their
    efficiency determines the loss.

    Parameters
    ----------
    n : pypsa.Network
        Network object with valid flow data.
    snapshots : pd.Index or list
        Snapshots for which the flows and losses are calculated. They must be
        a subset of n.snapshots. The default is None, which results
        in n.snapshots.
    branch_components : list
        Components for which the allocation should be calculated.
        The default is None, which results in n.passive_branch_components.

    Returns
    -------
    xarray.Dataset
        Resulting loss allocation of dimension {branch, country, snapshot} with
        variables [flow_with, loss_with, flow_without, loss_without].

    """
    branch_components = check_passive_branch_comps(branch_components, n)
    snapshots = check_snapshots(snapshots, n)
    regions = pd.Index(n.buses.country.unique(), name='country')
    branches = n.branches().loc[branch_components]
    f = network_flow(n, snapshots, branch_components)

    def regional_with_and_without_flow(region):
        in_region_buses = n.buses.query('country == @region').index
        region_branches = branches.query('bus0 in @in_region_buses '
                                         'or bus1 in @in_region_buses')
        buses_i = (pd.Index(region_branches.bus0.unique())
                   | pd.Index(region_branches.bus1.unique()) | in_region_buses)
        vicinity_buses = buses_i.difference(in_region_buses)
        branches_i = region_branches.index

        K = Incidence(n, branch_components).loc[buses_i]
        # create regional injection pattern with nodal injection at the border
        # accounting for the cross border flow
        p = (K @ f)
        # p.loc[in_region_buses] ==
        #     network_injection(n, snapshots).loc[snapshots, in_region_buses].T

        # modified injection pattern without transition
        im = upper(p.loc[vicinity_buses])
        ex = lower(p.loc[vicinity_buses])

        largerImport_b = im.sum('bus') > -ex.sum('bus')
        scaleImport = (im.sum('bus') + ex.sum('bus')) / im.sum('bus')
        scaleExport = (im.sum('bus') + ex.sum('bus')) / ex.sum('bus')
        netImOrEx = (im * scaleImport).where(largerImport_b,
                                             (ex * scaleExport))
        p_wo = xr.concat([p.loc[in_region_buses], netImOrEx], dim='bus')\
                 .reindex(bus=buses_i).fillna(0)

        if 'Link' in branch_components:
            H = xr.concat((PTDF(n, branch_components, snapshot=sn)
                           for sn in snapshots), dim='snapshot')\
                  .sel(branch=branches_i)
            # f == H @ p
        else:
            H = PTDF(n, branch_components).sel(branch=branches_i)
        f_wo = H.reindex(bus=buses_i).dot(p_wo, 'bus')

        res = Dataset({'flow_with_transit': f.sel(branch=branches_i),
                       'flow_without_transit': f_wo})\
            .assign_coords(country=region)
        return res.assign(transit_flow=res.flow_with_transit -
                          res.flow_without_transit)

    progress = ProgressBar()
    flows = xr.concat(
        (regional_with_and_without_flow(r) for r in progress(regions)),
        dim='country')
    comps = flows.get_index('branch').unique('component')
    loss = xr.concat(
        (flows.sel(component=c)**2 * DataArray(n.df(c).r_pu, dims='branch_i')
         if c in n.passive_branch_components else flows.sel(component=c) *
         DataArray(n.df(c).efficiency, dims='branch_i') for c in comps),
        dim=comps).stack(branch=['component', 'branch_i']).rename_vars(
            flow_with_transit='loss_with_transit',
            flow_without_transit='loss_without_transit',
            transit_flow='transit_flow_loss')
    return flows.merge(loss).assign_attrs(
        method='With-and-Without-Transit').fillna(0)
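A minimal usage sketch for the function above, assuming n is a solved pypsa.Network whose buses carry a country attribute (the file name and snapshot slice are illustrative):

import pypsa

n = pypsa.Network("solved_network.nc")  # hypothetical pre-solved network
ds = with_and_without_transit(n, snapshots=n.snapshots[:24])
# Loss attributable to transit flows, summed over the chosen snapshots
print(ds["transit_flow_loss"].sum("snapshot"))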
Example #30
0
    def save(
        self,
        savepath="./boutdata.nc",
        filetype="NETCDF4",
        variables=None,
        save_dtype=None,
        separate_vars=False,
        pre_load=False,
    ):
        """
        Save data variables to a netCDF file.

        Parameters
        ----------
        savepath : str, optional
            Path of the file to save to. Default is "./boutdata.nc".
        filetype : str, optional
            netCDF file format passed on to xarray's to_netcdf. Default is "NETCDF4".
        variables : list of str, optional
            Variables from the dataset to save. Default is to save all of them.
        separate_vars : bool, optional
            If this is true then every variable which depends on time (but not
            solely on time) will be saved into a different output file.
            The files are labelled by the name of the variable. Variables which
            don't meet this criterion will be present in every output file.
        pre_load : bool, optional
            When saving separate variables, will load each variable into memory
            before saving to file, which can be considerably faster.

        Examples
        --------
        If `separate_vars=True`, then multiple files will be created. These can
        all be opened and merged in one go using a call of the form:

        ds = xr.open_mfdataset('boutdata_*.nc', combine='nested', concat_dim=None)
        """

        if variables is None:
            # Save all variables
            to_save = self.data
        else:
            to_save = self.data[variables]

        if savepath == "./boutdata.nc":
            print(
                "Will save data into the current working directory, named as"
                " boutdata_[var].nc"
            )
        if savepath is None:
            raise ValueError("Must provide a path to which to save the data.")

        # make shallow copy of Dataset, so we do not modify the attributes of the data
        # when we change things to save
        to_save = to_save.copy()

        options = to_save.attrs.pop("options")
        if options:
            # TODO Convert Ben's options class to a (flattened) nested
            # dictionary then store it in ds.attrs?
            warnings.warn(
                "Haven't decided how to write options file back out yet - deleting "
                "options for now. To re-load this Dataset, pass the same inputfilepath "
                "to open_boutdataset when re-loading."
            )
        # Delete placeholders for options on each variable and coordinate
        for var in chain(to_save.data_vars, to_save.coords):
            try:
                del to_save[var].attrs["options"]
            except KeyError:
                pass

        # Store the metadata as individual attributes instead because
        # netCDF can't handle storing arbitrary objects in attrs
        def dict_to_attrs(obj, section):
            for key, value in obj.attrs.pop(section).items():
                obj.attrs[section + ":" + key] = value

        dict_to_attrs(to_save, "metadata")
        # Must do this for all variables and coordinates in dataset too
        for varname, da in chain(to_save.data_vars.items(), to_save.coords.items()):
            try:
                dict_to_attrs(da, "metadata")
            except KeyError:
                pass

        if "regions" in to_save.attrs:
            # Do not need to save regions as these can be reconstructed from the metadata
            try:
                del to_save.attrs["regions"]
            except KeyError:
                pass
            for var in chain(to_save.data_vars, to_save.coords):
                try:
                    del to_save[var].attrs["regions"]
                except KeyError:
                    pass

        if save_dtype is not None:
            encoding = {v: {"dtype": save_dtype} for v in to_save}
        else:
            encoding = None

        if separate_vars:
            # Save each major variable to a different netCDF file

            # Determine which variables are "major"
            # Defined as time-dependent, but not solely time-dependent
            major_vars, minor_vars = _find_major_vars(to_save)

            print("Will save the variables {} separately".format(str(major_vars)))

            # Save each one to separate file
            # TODO perform the save in parallel with save_mfdataset?
            for major_var in major_vars:
                # Group variables so that there is only one time-dependent
                # variable saved in each file
                minor_data = [to_save[minor_var] for minor_var in minor_vars]
                single_var_ds = xr.merge([to_save[major_var], *minor_data])

                # Add the attrs back on
                single_var_ds.attrs = to_save.attrs

                if pre_load:
                    single_var_ds.load()

                # Include the name of the variable in the name of the saved
                # file
                path = Path(savepath)
                var_savepath = (
                    str(path.parent / path.stem) + "_" + str(major_var) + path.suffix
                )
                if encoding is not None:
                    var_encoding = {major_var: encoding[major_var]}
                else:
                    var_encoding = None
                print("Saving " + major_var + " data...")
                with ProgressBar():
                    single_var_ds.to_netcdf(
                        path=str(var_savepath),
                        format=filetype,
                        compute=True,
                        encoding=var_encoding,
                    )

                # Force memory deallocation to limit RAM usage
                single_var_ds.close()
                del single_var_ds
                gc.collect()
        else:
            # Save data to a single file
            print("Saving data...")
            with ProgressBar():
                to_save.to_netcdf(
                    path=savepath, format=filetype, compute=True, encoding=encoding
                )

        return
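A minimal usage sketch for the method above, assuming ds is an object exposing this save method (for example a BoutDataset-style wrapper; the paths are illustrative):

import xarray as xr

# Write every time-dependent variable to its own file, loading each into
# memory before writing; files are created as ./boutdata_<var>.nc
ds.save(savepath="./boutdata.nc", separate_vars=True, pre_load=True)

# Recombine the per-variable files later, as the docstring suggests
recombined = xr.open_mfdataset("boutdata_*.nc", combine="nested", concat_dim=None)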
Example #31
0
 def plink_free_gwas(self, validate=None, plot=False, causal_pos=None,
                     pca=None, stmd=False, high_precision=False,
                     high_precision_on_zero=False, **kwargs):
     """
     Compute the least square regression for a genotype in a phenotype. This
     assumes that the phenotype has been computed from a nearly independent
     set of variants to be accurate (I believe that that is the case for
     most programs but it is not "advertised")
     """
     seed = self.seed
     print('Performing GWAS\n    Using seed', seed)
     now = time.time()
     pfn = '%s_phenos.hdf5' % self.outpref
     gfn = '%s.geno.hdf5' % self.outpref
     pcn =  '%s.pcs' % self.outpref
     #daskpheno = da.from_array(self.pheno.PHENO.values).astype(np.float)
     #daskgeno = self.geno.rechunk({1: self.geno.shape[1]})
     if os.path.isfile(pfn):
         res, x_train, x_test, y_train, y_test = self.load_previous_run()
     else:
         np.random.seed(seed=seed)
         if validate is not None:
             opt = {'threads':self.threads, 'memory': self.max_memory}
             print('making the crossvalidation data')
             arrays = train_test_split(self.geno, self.pheno,
                                       random_state=seed,
                                       test_size=1 / validate)
             arrays = [arr.rechunk('auto') #estimate_chunks(arr.shape, **opt))
                       if isinstance(arr, da.core.Array) else arr
                       for arr in arrays]
             x_train, x_test, y_train, y_test = arrays
         else:
             x_train, x_test = self.geno, self.geno
             y_train, y_test = self.pheno, self.pheno
         #assert not da.isnan(x_train).any().compute(threads=self.threads)
         # write test and train IDs
         opts = dict(sep=' ', index=False, header=False)
         y_test.to_csv('%s_testIDs.txt' % self.outpref, **opts)
         y_train.to_csv('%s_trainIDs.txt' % self.outpref, **opts)
         # if isinstance(x_train, dask.array.core.Array):
         #     x_train = x_train.rechunk((x_train.shape[0], 1)).astype(
         #         np.float)
         if 'normalize' in kwargs:
             if kwargs['normalize']:
                 print('Normalizing train set to variance 1 and mean 0')
                 x_train = (x_train - x_train.mean(axis=0)) / x_train.std(
                     axis=0)
                 print('Normalizing test set to variance 1 and mean 0')
                 x_test = (x_test - x_test.mean(axis=0)) / x_test.std(axis=0
                                                                      )
          # Get appropriate function for linear regression
         func = self.nu_linregress if high_precision else self.st_mod \
             if stmd else self.linregress
         if not isinstance(y_train, da.core.Array):
             daskpheno = da.from_array(y_train.PHENO.values).astype(np.float
                                                                    )
         else:
             daskpheno = y_train
         if pca is not None:
             func = self.st_mod  # Force function to statsmodels
             print('Computing PCs with %d components' % pca)
             if os.path.isfile(pcn):
                 pcs = pd.read_csv(pcn, sep='\t')
             else:
                 # Perform PCA
                 with ProgressBar():
                     pcs = pd.DataFrame(self.do_pca(x_train, pca))  # Estimate PCAs
                 pcs.to_csv(pcn, sep='\t', index=False)
             if self.covs is not None:
                 covs_train = y_train.reindex(columns='iid').merge(
                     self.covs,  on=['iid'], how='left')
                 assert covs_train.shape[0] == y_train.shape[0]
                 covs = pd.concat([covs_train, pcs], axis=1)
             else:
                 pcs['fid'] = y_train.fid
                 pcs['iid'] = y_train.iid
                 covs = pcs
             combos = product((x_train[:, x] for x in range(
                 x_train.shape[1])), [daskpheno], [covs])
             delayed_results = [dask.delayed(func)(x, y, cov) for x, y, cov
                                in combos]
         else:
             combos = product((x_train[:, x] for x in range(
                 x_train.shape[1])), [daskpheno])
             delayed_results = [dask.delayed(func)(x, y) for x, y in combos]
         print('Performing regressions')
         with ProgressBar():
             r = dask.compute(*delayed_results, scheduler='threads',
                              num_workers=self.threads, cache=self.cache)
         gc.collect()
         try:
             res = pd.DataFrame.from_records(list(r), columns=r[0]._fields)
         except AttributeError:
             res = pd.DataFrame(r)
         assert res.shape[0] == self.bim.shape[0]
         # Combine mapping and gwas
         res = pd.concat((res, self.bim.reset_index()), axis=1)
         # check precision issues and re-run the association
         zeros = res[res.pvalue == 0.0]
         if not zeros.empty and not stmd and high_precision_on_zero:
             print('    Processing zeros with arbitrary precision')
             df = x_train.shape[0] - 2
              combos = product([df], zeros.rvalue.values)
             with ThreadPool(self.threads) as p:
                 results = p.starmap(self.high_precision_pvalue, combos)
              zero_res = np.array(results)
             res.loc[res.pvalue == 0.0, 'pvalue'] = zero_res
             res['pvalue'] = [mp.mpf(z) for z in res.pvalue]
         self.p_values = res.pvalue.values
          # Make a Manhattan plot
         if plot:
             self.manhattan_plot(causal_pos, alpha=plot)
         # write files
         res.to_csv('%s.gwas' % self.outpref, sep='\t', index=False)
         labels = ['/x_train', '/x_test']
         arrays = [x_train, x_test]
         hdf_opt = dict(table=True, mode='a', format="table")
         y_train.to_hdf(pfn, 'y_train', **hdf_opt)
         y_test.to_hdf(pfn, 'y_test', **hdf_opt)
         assert len(x_train.shape) == 2
         assert len(x_test.shape) == 2
         chunks = np.array([x_train.shape, x_test.shape])
         np.save('chunks.npy', chunks)
         data = dict(zip(labels, arrays))
         da.to_hdf5(gfn, data)
     print('GWAS DONE after %.2f seconds !!' % (time.time() - now))
     self.sum_stats = res
     self.x_train = x_train
     self.x_test = x_test
     self.y_train = y_train
     self.y_test = y_test
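The core dask pattern in the method above is one delayed regression per variant, computed under a progress bar. A stripped-down sketch of that pattern, with scipy.stats.linregress standing in for the class's own regression helpers and purely illustrative array shapes:

import dask
import dask.array as da
import numpy as np
from dask.diagnostics import ProgressBar
from scipy import stats

def regress(x, y):
    # Ordinary least squares of one variant against the phenotype
    return stats.linregress(np.asarray(x), np.asarray(y))

x = da.random.random((1000, 50), chunks=(1000, 10))  # individuals x variants
y = da.random.random(1000, chunks=1000)              # phenotype

delayed_results = [dask.delayed(regress)(x[:, i], y) for i in range(x.shape[1])]
with ProgressBar():
    results = dask.compute(*delayed_results, scheduler='threads', num_workers=4)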
Example #32
0
def main(argv=sys.argv[1:]):
    global LOG
    from satpy import Scene
    from satpy.resample import get_area_def
    from satpy.writers import compute_writer_results
    from dask.diagnostics import ProgressBar
    from polar2grid.core.script_utils import (
        setup_logging, rename_log_file, create_exc_handler)
    import argparse
    prog = os.getenv('PROG_NAME', sys.argv[0])
    # "usage: " will be printed at the top of this:
    usage = """
    %(prog)s -h
see available products:
    %(prog)s -r <reader> -w <writer> --list-products -f file1 [file2 ...]
basic processing:
    %(prog)s -r <reader> -w <writer> [options] -f file1 [file2 ...]
basic processing with limited products:
    %(prog)s -r <reader> -w <writer> [options] -p prod1 prod2 -f file1 [file2 ...]
"""
    parser = argparse.ArgumentParser(prog=prog, usage=usage,
                                     description="Load, composite, resample, and save datasets.")
    parser.add_argument('-v', '--verbose', dest='verbosity', action="count", default=0,
                        help='each occurrence increases verbosity 1 level through ERROR-WARNING-INFO-DEBUG (default INFO)')
    parser.add_argument('-l', '--log', dest="log_fn", default=None,
                        help="specify the log filename")
    parser.add_argument('--progress', action='store_true',
                        help="show processing progress bar (not recommended for logged output)")
    parser.add_argument('--num-workers', type=int, default=4,
                        help="specify number of worker threads to use (default: 4)")
    parser.add_argument('--match-resolution', dest='preserve_resolution', action='store_false',
                        help="When using the 'native' resampler for composites, don't save data "
                             "at its native resolution, use the resolution used to create the "
                             "composite.")
    parser.add_argument('-w', '--writers', nargs='+',
                        help='writers to save datasets with')
    parser.add_argument("--list-products", dest="list_products", action="store_true",
                        help="List available reader products and exit")
    subgroups = add_scene_argument_groups(parser)
    subgroups += add_resample_argument_groups(parser)

    argv_without_help = [x for x in argv if x not in ["-h", "--help"]]
    args, remaining_args = parser.parse_known_args(argv_without_help)

    # get the logger if we know the readers and writers that will be used
    if args.reader is not None and args.writers is not None:
        glue_name = args.reader + "_" + "-".join(args.writers or [])
        LOG = logging.getLogger(glue_name)
    # add writer arguments
    if args.writers is not None:
        for writer in (args.writers or []):
            parser_func = WRITER_PARSER_FUNCTIONS.get(writer)
            if parser_func is None:
                continue
            subgroups += parser_func(parser)
    args = parser.parse_args(argv)

    if args.reader is None:
        parser.print_usage()
        parser.exit(1, "\nERROR: Reader must be provided (-r flag).\n"
                       "Supported readers:\n\t{}\n".format('\n\t'.join(['abi_l1b', 'ahi_hsd', 'hrit_ahi'])))
    if args.writers is None:
        parser.print_usage()
        parser.exit(1, "\nERROR: Writer must be provided (-w flag) with one or more writer.\n"
                       "Supported writers:\n\t{}\n".format('\n\t'.join(['geotiff'])))

    def _args_to_dict(group_actions):
        return {ga.dest: getattr(args, ga.dest) for ga in group_actions if hasattr(args, ga.dest)}
    scene_args = _args_to_dict(subgroups[0]._group_actions)
    load_args = _args_to_dict(subgroups[1]._group_actions)
    resample_args = _args_to_dict(subgroups[2]._group_actions)
    writer_args = {}
    for idx, writer in enumerate(args.writers):
        sgrp1, sgrp2 = subgroups[3 + idx * 2: 5 + idx * 2]
        wargs = _args_to_dict(sgrp1._group_actions)
        if sgrp2 is not None:
            wargs.update(_args_to_dict(sgrp2._group_actions))
        writer_args[writer] = wargs
        # get default output filename
        if 'filename' in wargs and wargs['filename'] is None:
            wargs['filename'] = get_default_output_filename(args.reader, writer)

    if not args.filenames:
        parser.print_usage()
        parser.exit(1, "\nERROR: No data files provided (-f flag)\n")

    # Prepare logging
    rename_log = False
    if args.log_fn is None:
        rename_log = True
        args.log_fn = glue_name + "_fail.log"
    levels = [logging.ERROR, logging.WARN, logging.INFO, logging.DEBUG]
    setup_logging(console_level=levels[min(3, args.verbosity)], log_filename=args.log_fn)
    logging.getLogger('rasterio').setLevel(levels[min(2, args.verbosity)])
    sys.excepthook = create_exc_handler(LOG.name)
    if levels[min(3, args.verbosity)] > logging.DEBUG:
        import warnings
        warnings.filterwarnings("ignore")
    LOG.debug("Starting script with arguments: %s", " ".join(sys.argv))

    # Set up dask and the number of workers
    if args.num_workers:
        from multiprocessing.pool import ThreadPool
        dask.config.set(pool=ThreadPool(args.num_workers))

    # Parse provided files and search for files if provided directories
    scene_args['filenames'] = get_input_files(scene_args['filenames'])
    # Create a Scene, analyze the provided files
    LOG.info("Sorting and reading input files...")
    try:
        scn = Scene(**scene_args)
    except ValueError as e:
        LOG.error("{} | Enable debug message (-vvv) or see log file for details.".format(str(e)))
        LOG.debug("Further error information: ", exc_info=True)
        return -1
    except OSError:
        LOG.error("Could not open files. Enable debug message (-vvv) or see log file for details.")
        LOG.debug("Further error information: ", exc_info=True)
        return -1

    if args.list_products:
        print("\n".join(sorted(scn.available_dataset_names(composites=True))))
        return 0

    # Rename the log file
    if rename_log:
        rename_log_file(glue_name + scn.attrs['start_time'].strftime("_%Y%m%d_%H%M%S.log"))

    # Load the actual data arrays and metadata (lazy loaded as dask arrays)
    if load_args['products'] is None:
        try:
            reader_mod = importlib.import_module('polar2grid.readers.' + scene_args['reader'])
            load_args['products'] = reader_mod.DEFAULT_PRODUCTS
            LOG.info("Using default product list: {}".format(load_args['products']))
        except (ImportError, AttributeError):
            LOG.error("No default products list set, please specify with `--products`.")
            return -1

    LOG.info("Loading product metadata from files...")
    scn.load(load_args['products'])

    resample_kwargs = resample_args.copy()
    areas_to_resample = resample_kwargs.pop('grids')
    grid_configs = resample_kwargs.pop('grid_configs')
    resampler = resample_kwargs.pop('resampler')

    if areas_to_resample is None and resampler in [None, 'native']:
        # no areas specified
        areas_to_resample = ['MAX']
    elif areas_to_resample is None:
        raise ValueError("Resampling method specified (--method) without any destination grid/area (-g flag).")
    elif not areas_to_resample:
        # they don't want any resampling (they used '-g' with no args)
        areas_to_resample = [None]

    has_custom_grid = any(g not in ['MIN', 'MAX', None] for g in areas_to_resample)
    if has_custom_grid and resampler == 'native':
        LOG.error("Resampling method 'native' can only be used with 'MIN' or 'MAX' grids "
                  "(use 'nearest' method instead).")
        return -1

    p2g_grid_configs = [x for x in grid_configs if x.endswith('.conf')]
    pyresample_area_configs = [x for x in grid_configs if not x.endswith('.conf')]
    if not grid_configs or p2g_grid_configs:
        # if we were given p2g grid configs or we weren't given any to choose from
        from polar2grid.grids import GridManager
        grid_manager = GridManager(*p2g_grid_configs)
    else:
        grid_manager = {}

    if pyresample_area_configs:
        from pyresample.utils import parse_area_file
        custom_areas = parse_area_file(pyresample_area_configs)
        custom_areas = {x.area_id: x for x in custom_areas}
    else:
        custom_areas = {}

    ll_bbox = resample_kwargs.pop('ll_bbox')
    if ll_bbox:
        scn = scn.crop(ll_bbox=ll_bbox)

    wishlist = scn.wishlist.copy()
    preserve_resolution = get_preserve_resolution(args, resampler, areas_to_resample)
    if preserve_resolution:
        preserved_products = set(wishlist) & set(scn.datasets.keys())
        resampled_products = set(wishlist) - preserved_products

        # original native scene
        to_save = write_scene(scn, args.writers, writer_args, preserved_products)
    else:
        preserved_products = set()
        resampled_products = set(wishlist)
        to_save = []

    LOG.debug("Products to preserve resolution for: {}".format(preserved_products))
    LOG.debug("Products to use new resolution for: {}".format(resampled_products))
    for area_name in areas_to_resample:
        if area_name is None:
            # no resampling
            area_def = None
        elif area_name == 'MAX':
            area_def = scn.max_area()
        elif area_name == 'MIN':
            area_def = scn.min_area()
        elif area_name in custom_areas:
            area_def = custom_areas[area_name]
        elif area_name in grid_manager:
            from pyresample.geometry import DynamicAreaDefinition
            p2g_def = grid_manager[area_name]
            area_def = p2g_def.to_satpy_area()
            if isinstance(area_def, DynamicAreaDefinition) and p2g_def['cell_width'] is not None:
                area_def = area_def.freeze(scn.max_area(),
                                           resolution=(abs(p2g_def['cell_width']), abs(p2g_def['cell_height'])))
        else:
            area_def = get_area_def(area_name)

        if resampler is None and area_def is not None:
            rs = 'native' if area_name in ['MIN', 'MAX'] else 'nearest'
            LOG.debug("Setting default resampling to '{}' for grid '{}'".format(rs, area_name))
        else:
            rs = resampler

        if area_def is not None:
            LOG.info("Resampling data to '%s'", area_name)
            new_scn = scn.resample(area_def, resampler=rs, **resample_kwargs)
        elif not preserve_resolution:
            # the user didn't want to resample to any areas
            # the user also requested that we don't preserve resolution
            # which means we have to save this Scene's datasets
            # because they won't be saved
            new_scn = scn

        to_save = write_scene(new_scn, args.writers, writer_args, resampled_products, to_save=to_save)

    if args.progress:
        pbar = ProgressBar()
        pbar.register()

    LOG.info("Computing products and saving data to writers...")
    compute_writer_results(to_save)
    LOG.info("SUCCESS")
    return 0
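The two dask-related pieces of the script above, a fixed thread pool and an optionally registered global progress bar, can be exercised in isolation. A minimal sketch, with a toy dask computation standing in for compute_writer_results:

import dask
import dask.array as da
from multiprocessing.pool import ThreadPool
from dask.diagnostics import ProgressBar

dask.config.set(pool=ThreadPool(4))  # same effect as --num-workers 4

pbar = ProgressBar()
pbar.register()                      # equivalent of passing --progress
try:
    result = da.random.random((2000, 2000), chunks=500).mean().compute()
finally:
    pbar.unregister()
print(result)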
Example #33
0
 def read_geno(bfile, freq_thresh, threads, check=False, max_memory=None,
               usable_snps=None, normalize=False, prefix='my_geno',
               thinning=None):
     chunks = (10000, 10000)
     # set Cache to protect memory spilling
     if max_memory is not None:
         available_memory = max_memory
     else:
         available_memory = psutil.virtual_memory().available
     cache = Chest(available_memory=available_memory)
     (bim, fam, g) = read_plink(bfile)  # read the files using pandas_plink
     g_std = da.nanstd(g, axis=1)
     if check:
         with ProgressBar():
             print('Removing invariant sites')
             idx = (g_std != 0).compute(cache=cache)
         g = g[idx, :]
         bim = bim[idx].copy().reset_index(drop=True)
         bim.i = bim.index.tolist()
         g_std = g_std[idx]
         del idx
         gc.collect()
     if usable_snps is not None:
         print('Restricting genotype to user specified variants')
         idx = sorted(bim[bim.snp.isin(usable_snps)].i.values)
         g = g[idx, :]
         bim = bim[bim.i.isin(idx)].copy().reset_index(drop=True)
         bim.i = bim.index.tolist()
     mafs = g.sum(axis=1) / (2 * g.shape[0]) if freq_thresh > 0 else None
     # Filter MAF
     if freq_thresh > 0:
         print('Filtering MAFs smaller than', freq_thresh)
         print('    Genotype matrix shape before', g.shape)
         assert freq_thresh < 0.5
         good = (mafs < (1 - float(freq_thresh))) & (mafs > float(
             freq_thresh))
         with ProgressBar():
             with dask.config.set(pool=ThreadPool(threads)):
                 good, mafs = dask.compute(good, mafs, cache=cache)
         g = g[good, :]
         print('    Genotype matrix shape after', g.shape)
         bim = bim[good]
         bim['mafs'] = mafs[good]
         bim.reset_index(drop=True, inplace=True)
         bim.i = bim.index.tolist()
         del good
         gc.collect()
     if not is_transposed(g, bim.shape[0], fam.shape[0]):
         g = g.T
     if normalize:
         print('Normalizing to mean 0 and sd 1')
         mean = da.nanmean(g.T, axis=1)
         g = (g - mean) / g_std
     if thinning is not None:
         print("Thinning genotype to %d variants" % thinning)
         idx = np.linspace(0, g.shape[1], num=thinning, dtype=int,
                           endpoint=False)
         bim = bim.reindex(index=idx)
         g = g[:, idx].rechunk('auto')
         bim['i'] = range(thinning)
     h5 = '%s.hdf5' % prefix
     if not os.path.isfile(h5):
          with ProgressBar(), h5py.File(h5, 'w') as hd5:
             print("Sending processed genotype to HDF5")
             chroms = sorted(bim.chrom.unique().astype(int))
             gr = bim.groupby('chrom')
             for chrom in chroms:
                 df = gr.get_group(str(chrom))
                 ch = g[:, df.i.values]
                 ch = ch.rechunk(estimate_chunks(ch.shape, threads,
                                                 memory=available_memory))
                 print('\tChromosome %s: %d individuals %d  variants' % (
                     chrom, ch.shape[0], ch.shape[1]))
                 hd5.create_dataset('/%s' % chrom,  data=ch.compute())
                 del ch
             del gr
     return g, h5, bim, fam #g, bim, fam
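A minimal call sketch for the function above, assuming mydata.bed/.bim/.fam is a plink fileset in the working directory (prefix and thresholds are illustrative):

g, h5, bim, fam = read_geno('mydata', freq_thresh=0.01, threads=4,
                            check=True, normalize=True)
print(bim.shape[0], 'variants kept; genotype array shape', g.shape)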
Example #34
0
    def read_geno(bfile, freq_thresh, threads, flip=False, check=False,
                  max_memory=None, usable_snps=None):
        """
        Read the plink bed fileset, restrict to a given frequency (optional,
        freq_thresh), flip the sequence to match the MAF (optional; flip), and
        check if constant variants present (optional; check)

        :param max_memory: Maximum allowed memory
        :param bfile: Prefix of the bed (plink) fileset
        :param freq_thresh: If greater than 0, limit MAF to at least freq_thresh
        :param threads: Number of threads to use in computation
        :param flip: Whether to check for flips and to fix the genotype file
        :param check: Whether to check for constant sites
        :return: Dataframes (bim, fam) and array corresponding to the bed fileset
        """
        # set Cache to protect memory spilling
        if max_memory is not None:
            available_memory = max_memory
        else:
            available_memory = psutil.virtual_memory().available
        cache = Chest(available_memory=available_memory)
        (bim, fam, g) = read_plink(bfile)  # read the files using pandas_plink
        # mask nans in the genotype matrix
        g = da.ma.masked_array(g, mask=da.isnan(g))
        m, n = g.shape  # get the dimensions of the genotype
        # remove invariant sites
        if check:
            g_std = g.std(axis=1)
            with ProgressBar(), dask.config.set(pool=ThreadPool(threads)):
                print('Removing invariant sites')
                idx = (g_std != 0).compute(cache=cache)
            g = g[idx, :]
            bim = bim[idx].copy().reset_index(drop=True)
            bim.i = bim.index.tolist()
            del g_std, idx
            gc.collect()
        if usable_snps is not None:
            print('Subsetting to specified SNPs')
            idx = bim[bim.snp.isin(usable_snps)].i.values
            idx.sort()
            g = g[idx, :]
            bim = bim[bim.i.isin(idx)].copy().reset_index(drop=True)
            bim.i = bim.index.tolist()
        # compute the mafs if required
        mafs = g.sum(axis=1) / (2 * n) if flip or freq_thresh > 0 else None
        if flip:
            print('Checking for flips')
            # check possible flips
            flips = np.zeros(bim.shape[0], dtype=bool)
            flips[np.where(mafs > 0.5)[0]] = True
            bim['flip'] = flips
            vec = np.zeros(flips.shape[0])
            vec[flips] = 2
            # perform the flipping
            g = abs(g.T - vec)
            del flips
            gc.collect()
        else:
            g = g.T
        # Filter MAF
        if freq_thresh > 0:
            print('Filtering MAFs smaller than', freq_thresh)
            print('    Genotype matrix shape before', g.shape)
            assert freq_thresh < 0.5
            good = (mafs < (1 - float(freq_thresh))) & (
                        mafs > float(freq_thresh))
            with ProgressBar():
                with dask.config.set(pool=ThreadPool(threads)):
                    good, mafs = dask.compute(good, mafs, cache=cache)
            g = g[:, good]
            print('    Genotype matrix shape after', g.shape)
            bim = bim[good]
            bim['mafs'] = mafs[good]
            del good
            gc.collect()
        bim = bim.reset_index(drop=True)  # Get the indices in order
        # Fix the i such that it matches the genotype indices
        bim['i'] = bim.index.tolist()
        # Get chunks appropriate for the number of threads
        print('Rechunking genotype file')
        print('\tChunks before', g.chunksize)
        g = g.rechunk(estimate_chunks(g.shape, threads, memory=available_memory
                                      ))
        print('\tChunks after', g.chunksize)
        del mafs
        gc.collect()
        return bim, fam, g
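The flip step above remaps genotypes coded against the minor allele: for variants with MAF > 0.5 the dosage g in {0, 1, 2} is replaced by |g - 2|, which swaps 0 and 2 and leaves 1 unchanged. A tiny numeric check:

import numpy as np

g_row = np.array([0, 1, 2, 2, 0])  # dosages coded against the wrong allele
print(np.abs(g_row - 2))           # -> [2 1 0 0 2]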
def predict_xr(model, input_xr, progress=True):
    """
    Utilise our wrappers to predict with a vanilla sklearn model.

    Last modified: September 2019

    Parameters
    ----------
    model : a scikit-learn model or compatible object
        Must have a predict() method that takes numpy arrays.
    input_xr : xarray.DataArray or xarray.Dataset
        Must have dimensions 'x' and 'y', may have dimension 'time'.

    Returns
    -------
    output_xr : xarray.DataArray 
        An xarray.DataArray containing the prediction output from model 
        with input_xr as input. Has the same spatiotemporal structure 
        as input_xr.

    """
    def _get_class_ufunc(*args):
        """
        ufunc to apply classification to chunks of data
        """
        input_data_flattened = []
        for data in args:
            input_data_flattened.append(data.flatten())

        # Stack the flattened variables into a (n_samples, n_features) array
        input_data_flattened = np.array(input_data_flattened).transpose()

        # Mask out no-data in input (not all classifiers can cope with
        # Inf or NaN values)
        input_data_flattened = np.where(np.isfinite(input_data_flattened),
                                        input_data_flattened, 0)

        # Actually apply the classification
        out_class = model.predict(input_data_flattened)

        # Mask out NaN or Inf values in results
        out_class = np.where(np.isfinite(out_class), out_class, 0)

        # Reshape when writing out
        return out_class.reshape(args[0].shape)

    def _get_class(*args):
        """
        Apply classification to xarray DataArrays.

        Uses dask to run chunks at a time in parallel

        """
        out = xr.apply_ufunc(_get_class_ufunc,
                             *args,
                             dask='parallelized',
                             output_dtypes=[np.uint8])

        return out

    # Set up a list of input data using variables passed in
    input_data = []

    for var_name in input_xr.data_vars:
        input_data.append(input_xr[var_name])

    # Run through classification. Need to expand and have a separate
    # dataframe for each variable so chunking in dask works.
    if progress:
        with ProgressBar():
            out_class = _get_class(*input_data).compute()
    else:
        out_class = _get_class(*input_data).compute()

    # Set the stacked coordinate to match the input
    output_xr = xr.DataArray(out_class, coords=input_xr.coords)

    return output_xr
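A minimal usage sketch for predict_xr, assuming a fitted scikit-learn classifier and a chunked (dask-backed) xarray Dataset whose data variables are the model's features; the dataset contents and model below are toy stand-ins:

import numpy as np
import xarray as xr
import dask.array as da
from sklearn.ensemble import RandomForestClassifier

ny, nx = 100, 100
ds = xr.Dataset(
    {"red": (("y", "x"), da.random.random((ny, nx), chunks=50)),
     "nir": (("y", "x"), da.random.random((ny, nx), chunks=50))},
    coords={"y": np.arange(ny), "x": np.arange(nx)},
)

# Toy classifier trained on two features to match the two data variables
model = RandomForestClassifier(n_estimators=10)
model.fit(np.random.random((20, 2)), np.random.randint(0, 2, 20))

predicted = predict_xr(model, ds, progress=True)  # shows a dask progress bar
print(predicted)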