Пример #1
0
    def _init_gui(self):
        """Create the control widgets displayed alongside the output.

        Returns:
            _ListDict: the ``close`` button, the ``frame`` slider and the
            ``cols`` / ``rows`` folders, keyed by those names.
        """
        close = Button(description=' Close', icon='trash', layout=_wlo)

        def _close(b):
            self.close()

        close.on_click(_close)

        # Convert the frame column once instead of once per slider bound.
        frames = self._df.frame.astype(int)
        # NOTE(review): value=-1 lies outside [min, max] whenever the first
        # frame is >= 0; presumably the slider coerces it -- confirm intent.
        frame = IntSlider(min=frames.min(),
                          max=frames.max(),
                          value=-1,
                          description='Frame',
                          layout=_wlo)

        cbut = Button(description=' Columns', icon='barcode')
        col_names = self._df.columns.tolist()
        cols = SelectMultiple(options=col_names, value=col_names)

        def _cols(c):
            # Store the selected columns and redraw the output.
            self.columns = c.new
            self._update_output()

        cols.observe(_cols, names='value')
        cfol = Folder(cbut, _ListDict([('cols', cols)]))

        rbut = Button(description=' Rows', icon='bars')
        rows = IntRangeSlider(min=self.indexes[0],
                              max=self.indexes[1],
                              value=[0, 50])

        def _rows(c):
            # Store the selected (start, stop) row range and redraw the
            # output; nothing is printed so the notebook output stays clean.
            self.indexes = c.new
            self._update_output()

        rows.observe(_rows, names='value')

        rfol = Folder(rbut, _ListDict([('rows', rows)]))

        return _ListDict([('close', close), ('frame', frame), ('cols', cfol),
                          ('rows', rfol)])
Пример #2
0
    def _init_gui(self):
        """Create the control widgets displayed alongside the output.

        Returns:
            _ListDict: ``close`` button, ``frame`` slider, and ``cols`` /
            ``rows`` folders, keyed by those names.
        """
        close = Button(description=' Close', icon='trash', layout=_wlo)
        def _close(b): self.close()
        close.on_click(_close)

        # Convert the frame column once instead of once per slider bound.
        # NOTE(review): value=-1 is outside [min, max] when frames start at
        # 0; presumably the slider coerces it -- confirm intent.
        frames = self._df.frame.astype(int)
        frame = IntSlider(min=frames.min(), max=frames.max(),
                          value=-1, description='Frame', layout=_wlo)

        cbut = Button(description=' Columns', icon='barcode')
        col_names = self._df.columns.tolist()
        cols = SelectMultiple(options=col_names, value=col_names)

        def _cols(c):
            # Store the selected columns and redraw the output.
            self.columns = c.new
            self._update_output()
        cols.observe(_cols, names='value')
        cfol = Folder(cbut, _ListDict([('cols', cols)]))

        rbut = Button(description=' Rows', icon='bars')
        rows = IntRangeSlider(min=self.indexes[0],
                              max=self.indexes[1],
                              value=[0, 50])
        def _rows(c):
            # Store the selected row range and redraw (no debug printing).
            self.indexes = c.new
            self._update_output()
        rows.observe(_rows, names='value')

        rfol = Folder(rbut, _ListDict([('rows', rows)]))

        return _ListDict([('close', close),
                          ('frame', frame),
                          ('cols', cfol),
                          ('rows', rfol)])
Пример #3
0
def view():
    """Build the widget used to browse parcel data stored on disk.

    The user picks a base folder (the configured temporary or personal data
    folder), a table sub-folder inside it, and one parcel (or several). The
    view methods offered depend on which entries exist in the selected
    parcel folder ("time_series", "chip_images", ...). A note can also be
    attached to the selected parcel.

    Returns:
        VBox: the assembled interactive widget.
    """
    info = Label("Select a parcel to display.")

    temppath = config.get_value(['paths', 'temp'])
    datapath = config.get_value(['paths', 'data'])

    # ----- filesystem helpers ---------------------------------------------
    def list_visible(folder):
        """Return the non-hidden (not dot-prefixed) entries of ``folder``."""
        return [f for f in os.listdir(folder) if not f.startswith('.')]

    def list_tables(folder):
        """Return the non-hidden sub-directories of ``folder``."""
        return [
            f for f in list_visible(folder)
            if os.path.isdir(os.path.join(folder, f))
        ]

    def table_path():
        """Return the path of the currently selected table folder."""
        # os.path.join works whether or not the configured folder ends with
        # a path separator; plain string concatenation required it to.
        return os.path.join(paths.value, select_table.value)

    # ----- widgets --------------------------------------------------------
    method = ToggleButtons(options=[('From local storage', 0),
                                    ('Remote to memory', 1)],
                           value=0,
                           description='',
                           disabled=True,
                           button_style='info',
                           tooltips=[
                               'View data that are stored on the local drive.',
                               'View data from memory.'
                           ])

    paths = RadioButtons(options=[
        (f"Temporary folder: '{temppath}'.", temppath),
        (f"Personal data folder: '{datapath}'.", datapath)
    ],
                         layout={'width': 'max-content'},
                         value=temppath)

    paths_box = Box([Label(value="Select folder:"), paths])

    select_table = Dropdown(
        options=list_tables(paths.value),
        value=None,
        description='Select table:',
        disabled=False,
    )

    select_option = RadioButtons(options=[(f"Single parcel selection.", 1),
                                          (f"Multiple parcels selection.", 2)],
                                 disabled=True,
                                 layout={'width': 'max-content'})

    button_refresh = Button(layout=Layout(width='35px'), icon='fa-refresh')

    select_option_box = HBox([
        select_table, button_refresh,
        Label(value="Selection method:"), select_option
    ])

    selection_single = Dropdown(
        options=[],
        value=None,
        description='Select parcel:',
        disabled=False,
    )

    selection_multi = SelectMultiple(
        options=[],
        value=[],
        description='Select parcels:',
        disabled=False,
    )

    view_method = ToggleButtons(
        options=[],
        value=None,
        description='',
        disabled=False,
        button_style='info',
        tooltips=[],
    )

    rm_parcel = Button(disabled=False,
                       button_style='danger',
                       tooltip='Delete parcel data.',
                       icon='trash',
                       layout=Layout(width='35px'))

    code_info = Label()
    single_box = HBox([selection_single, rm_parcel])
    select_box = Box([single_box])

    method_0 = VBox([info, paths_box, select_option_box, select_box])
    method_1 = VBox([])
    view_box = Output(layout=Layout(border='1px solid black'))
    method_out = Output()
    with method_out:
        display(method_0)

    # ----- callbacks ------------------------------------------------------
    def show_method(obj):
        """Switch between the local-storage and in-memory panels."""
        with method_out:
            method_out.clear_output()
            if obj['new'] == 0:
                display(method_0)
            elif obj['new'] == 1:
                display(method_1)

    method.observe(show_method, 'value')

    def refresh_parcels():
        """Reload both parcel lists from the selected table, or clear them."""
        if select_table.value is not None:
            parcels_list = list_visible(table_path())
            selection_single.options = parcels_list
            selection_multi.options = parcels_list
        else:
            selection_single.options = []
            selection_single.value = None
            selection_multi.options = []
            selection_multi.value = []

    @button_refresh.on_click
    def button_refresh_on_click(b):
        view_box.clear_output()
        select_table.options = list_tables(paths.value)
        refresh_parcels()

    @rm_parcel.on_click
    def rm_parcel_on_click(b):
        """Delete the selected parcel (folder or file) and refresh the lists."""
        if select_table.value is None or selection_single.value is None:
            return
        parcel_to_rm = os.path.join(table_path(), selection_single.value)
        try:
            # rmtree refuses symlinks, so only use it for real directories.
            if (os.path.isdir(parcel_to_rm)
                    and not os.path.islink(parcel_to_rm)):
                shutil.rmtree(parcel_to_rm)
            else:
                os.remove(parcel_to_rm)
            parcels_list = list_visible(table_path())
        except OSError as err:
            # Report the failure instead of silently ignoring it.
            with view_box:
                print(f"Could not delete '{selection_single.value}': {err}")
            return
        selection_single.options = parcels_list
        selection_multi.options = parcels_list
        view_box.clear_output()

    def on_select_option_change(change):
        if select_option.value == 1:
            select_box.children = [single_box]
        else:
            select_box.children = [selection_multi]

    select_option.observe(on_select_option_change, 'value')

    def on_datapath_change(change):
        select_table.options = list_tables(paths.value)

    paths.observe(on_datapath_change, 'value')

    def on_table_change(change):
        refresh_parcels()
        if select_table.value is None:
            view_method.options = []

    select_table.observe(on_table_change, 'value')

    def on_selection_change(obj):
        """Offer the view methods supported by the selected parcel's data."""
        code_info.value = "Select how to view the dataset."
        options_list = [('Get example code', 1)]
        new = obj['new']
        # The multi-select reports a tuple of parcels; inspect the first one.
        if isinstance(new, tuple):
            parcel = new[0] if new else None
        else:
            parcel = new
        if parcel is not None and select_table.value is not None:
            data_list = list_visible(os.path.join(table_path(), parcel))
            if any("time_series" in s for s in data_list):
                options_list.append(('Plot time series', 2))
            if any("chip_images" in s for s in data_list):
                options_list.append(('View images', 3))
            options_list.append(("Show on map", 4))
            if select_option.value == 2:
                options_list.append(('Comparison', 5))
            view_method.options = options_list
            view_method.value = None

    selection_single.observe(on_selection_change, 'value')
    selection_multi.observe(on_selection_change, 'value')

    def show_single(choice):
        """Display view ``choice`` for the single selected parcel."""
        # Trailing separator kept: the viewers receive a folder path.
        data_path = os.path.join(table_path(), selection_single.value, '')
        if choice == 1:
            from src.ipycbm.ui_view import view_code
            display(view_code.code(data_path))
        elif choice == 2:
            from src.ipycbm.ui_view import view_time_series
            display(view_time_series.time_series(data_path))
        elif choice == 3:
            from src.ipycbm.ui_view import view_calendar
            display(view_calendar.calendar(data_path))
        elif choice == 4:
            from src.ipycbm.ui_view import view_map
            display(view_map.widget_box(data_path))

    def show_multi(choice):
        """Display view ``choice`` for all selected parcels."""
        data_paths = [
            os.path.join(table_path(), s, '') for s in selection_multi.value
        ]
        if choice == 1:
            from src.ipycbm.ui_view import view_code
            display(view_code.code(data_paths[0]))
        elif choice == 4:
            from src.ipycbm.ui_view import view_maps
            display(view_maps.with_polygons(data_paths))
        # TODO: time series (2) and chip images (3) are not implemented
        # for multiple parcels yet.

    def view_method_options(obj):
        """Render the chosen view method into ``view_box``."""
        view_box.clear_output()
        with view_box:
            if select_option.value == 1 and selection_single.value is not None:
                show_single(obj['new'])
            elif select_option.value == 2 and selection_multi.value:
                show_multi(obj['new'])
            else:
                print("Please select a parcel")

    selection_single.observe(view_method_options, 'value')
    selection_multi.observe(view_method_options, 'value')
    view_method.observe(view_method_options, 'value')

    # ----- notes ----------------------------------------------------------
    notes_info = Label("Add a note for the parcel")
    notes_bt = Button(description='Add note',
                      disabled=False,
                      button_style='info',
                      tooltip='Add a note.',
                      icon='sticky-note')
    notes_box = VBox([])

    @notes_bt.on_click
    def notes_bt_on_click(b):
        """Toggle the note editor for the selected parcel."""
        if notes_box.children:
            notes_box.children = []
        elif select_table.value is None or selection_single.value is None:
            with view_box:
                print("Please select a parcel")
        else:
            notes_box.children = [
                view_notes.notes(os.path.join(table_path(), ''),
                                 select_table.value,
                                 selection_single.value.replace('parcel_', ''))
            ]

    wbox = VBox([
        method_out, code_info, view_method, view_box,
        HBox([notes_info, notes_bt]), notes_box
    ])

    return wbox
Пример #4
0
class base_processor_widget(VBox):
    """
        Summary:
            Base widget for a signal processor. It lets the user choose the
            input signals and their components, hosts processor-specific
            option widgets (see add_options), and can be serialized to and
            restored from a dictionary (see dump / initialize).
    """
    def __init__(self, signal_components: dict, _type=""):
        """
            Summary:
                Create the signal and component selectors.

            Arguments:
                signal_components - dictionary mapping each signal name to
                the list of its components.
                _type - processor type; shown in the header and written to
                the "type" entry of dump().
        """
        self.type = _type
        self.components = signal_components

        # Signals available for selection; the leading empty option stands
        # for "all signals" (see get_signal_components).
        self.signals = ["", *signal_components]

        self.wsm_signals = SelectMultiple(options=self.signals,
                                          value=[""],
                                          description="Signals:",
                                          placeholder="Signals",
                                          disabled=False)

        self.wsm_components = SelectMultipleOrdered(
            options=self._components_for(signal_components),
            value=[""],
            description="Components:",
            placeholder="Components",
            disabled=False)

        def on_signal_change(change):
            # Offer only the components of the selected signals (all of
            # them when the empty option is selected).
            selected = self.wsm_signals.value
            keys = self.components if "" in selected else selected
            self.wsm_components.options = self._components_for(keys)

        self.wsm_signals.observe(on_signal_change, 'value')

        super().__init__(self._processor_header(),
                         layout=Layout(border='1px solid black'))

        self.options = []

    def _processor_header(self) -> list:
        """Return the fixed widgets shown above the processor options."""
        return [
            HTML(value=f"<B>Processor type: {self.type}</B>"),
            HBox([self.wsm_signals, self.wsm_components]),
            HTML(value="<B>Options :</B>"),
        ]

    def _components_for(self, signals) -> list:
        """Return [""] followed by the components of each of ``signals``."""
        return ["", *(comp for sig in signals for comp in self.components[sig])]

    @staticmethod
    def _option_key(widget) -> str:
        """Return the widget description without its trailing colon."""
        key = widget.description
        # endswith() also copes with an empty description.
        return key[:-1] if key.endswith(":") else key

    def add_options(self, options=None):
        """
        Summary:
            Add options related to a specific processor type.

        Arguments:
            options - list of widgets representing the options specific to
            the preprocessor. Repeated calls accumulate the options.

        Returns:
            Nothing.
        """
        if options is not None:
            self.children = [*self.children, *options]
            # Track every added option, matching the displayed children,
            # so dump() and initialize() see all of them.
            self.options = [*self.options, *options]

    def dump(self) -> dict:
        """
        Summary:
            Build and return a dictionary describing the preprocessor and its
            options.

        Arguments:
            None.

        Returns:
            Dictionary describing the preprocessor. "signals" / "components"
            are only present when an explicit selection was made (i.e. the
            empty "all" option is not selected); options whose value is None
            or an empty string are omitted, tuples are stored as lists.
        """
        out_dict = {"type": self.type}

        # Add the signals, if at least one explicit signal is selected
        signals = list(self.wsm_signals.value)
        if signals and "" not in signals:
            out_dict["signals"] = signals

        # Add the components, if at least one explicit component is selected
        components = list(self.wsm_components.ordered_value)
        if components and "" not in components:
            out_dict["components"] = components

        # Now add the options, if present
        for wd in self.options:
            value = wd.value

            if value is None:
                continue

            if isinstance(value, str) and value == "":
                continue

            if isinstance(value, tuple):
                value = list(value)

            out_dict[self._option_key(wd)] = value

        return out_dict

    def get_signal_components(self) -> dict:
        """
        Summary:
            Return the dictionary with signal and components at the output of
            the processor. An empty-option selection means "all" for both
            signals and components; component lists keep each signal's own
            component order.
        """
        if "" in self.wsm_signals.value:
            signals = list(self.components.keys())
        else:
            signals = self.wsm_signals.value

        selected = set(self.wsm_components.value)
        out_dict = {}
        for signal in signals:
            if "" in selected:
                out_dict[signal] = self.components[signal]
            else:
                # Ordered, de-duplicated intersection (a set intersection
                # would make the order arbitrary).
                out_dict[signal] = list(dict.fromkeys(
                    c for c in self.components[signal] if c in selected))

        return out_dict

    def initialize(self, options):
        """
        Summary:
            Initialize the processor using a dictionary, which needs to
            have the same format as that produced by the dump function.

        Arguments:
            options - dictionary with the options to initialize the processor.
            Signals and components not available in the selectors are
            ignored.

        Returns:
            Nothing.
        """
        # Add the signals (falling back to "all" if none is available)
        if "signals" in options:
            signals = [s for s in options["signals"]
                       if s in self.wsm_signals.options]
            self.wsm_signals.value = signals if signals else [""]
        else:
            self.wsm_signals.value = tuple([""])

        # Add the components; the options were refreshed by the signal
        # observer above, so filter against the current list.
        if "components" in options:
            self.wsm_components.value = [
                c for c in options["components"]
                if c in self.wsm_components.options
            ]
        else:
            self.wsm_components.value = tuple([""])

        # Now add specific options, if present
        for wd in self.options:
            key = self._option_key(wd)
            if key in options:
                if isinstance(wd, Text) and isinstance(options[key], list):
                    wd.value = ", ".join(options[key])
                else:
                    wd.value = options[key]

        # Set the children of the widget
        self.children = [*self._processor_header(), *self.options]
Пример #5
0
class MultiPlotABS(object):
    """Interactive log-log comparison plot of several ABS files.

    A multi-select list of file names drives the plot: each selected file
    is read with ``readABS`` and drawn as ``I`` versus ``q`` on one of
    ``max_lines`` pre-created marker lines. Status messages are printed to a
    text output area.

    Usage: ``widget = MultiPlotABS(files).run_widget()``.
    """

    def __init__(self, file_list, max_lines=10):
        """
        Args:
            file_list: paths (str or path-like) of the files to offer.
            max_lines: maximum number of files plotted at the same time.
        """
        self.file_list_path = [pathlib.Path(i) for i in file_list]
        # Names shown in the selector: last component of each path.
        self.file_list = [p.parts[-1] for p in self.file_list_path]
        self.max_lines = max_lines

    def update_plot(self, event):
        """Redraw the plot for the files selected in ``event['owner']``."""
        index_list = event['owner'].index

        # Blank every line first so deselected files disappear.
        for line in self.lines:
            line.set_xdata([])
            line.set_ydata([])
            line.set_label('')

        self.text_output.clear_output()
        shown = []
        with self.text_output:
            for i, index in enumerate(index_list):
                file = self.file_list[index]

                if i >= self.max_lines:
                    print('==> Skipping {}. Can only show {} at a time...'.
                          format(file, self.max_lines))
                    continue

                line = self.lines[i]
                print('--> Showing {}'.format(file))
                df, _ = readABS(self.file_list_path[index])
                line.set_xdata(df['q'].values)
                line.set_ydata(df['I'].values)
                line.set_label(file)
                shown.append(line)

        self.ax.relim()
        self.ax.autoscale_view()
        if shown:
            # Pass the handles explicitly so labels cannot get out of step.
            self.ax.legend(handles=shown)
        elif self.ax.get_legend() is not None:
            self.ax.get_legend().remove()
        # Artist changes are only visible after the canvas is redrawn.
        self.fig.canvas.draw_idle()

    def build_widget(self):
        """Create the selector, plot area and text area; return the layout."""
        # init interactive plot
        self.plot_output = Output(layout={'width': '600px'})

        self.text_output = Output()
        self.select = SelectMultiple(options=self.file_list,
                                     layout={'width': '400px'},
                                     rows=20)
        # Only react to selection changes; observing every trait would run
        # update_plot several times per click.
        self.select.observe(self.update_plot, names='index')

        VB = VBox([self.select], layout={'align_self': 'center'})
        HB = HBox([VB, self.plot_output])
        widget = VBox([HB, self.text_output])

        return widget

    def init_plot(self):
        """Create the figure and ``max_lines`` empty log-log marker lines."""
        with self.plot_output:
            fig, ax = plt.subplots()
            self.fig = fig
            self.ax = ax

        # init lines, one colour each
        colors = sns.palettes.color_palette(palette='bright',
                                            n_colors=self.max_lines)
        self.lines = []
        for color in colors[:self.max_lines]:
            line = plt.matplotlib.lines.Line2D([], [])
            line.set(color=color, marker='o', ms=3, ls='None')
            ax.add_line(line)
            self.lines.append(line)
        ax.set_xscale('log')
        ax.set_yscale('log')

    def run_widget(self):
        """Build the widget, initialise the plot and return the widget."""
        widget = self.build_widget()
        self.init_plot()
        return widget
Пример #6
0
class SelectMultipleTokensSelector(TokensSelector):
    """Token selector backed by a filterable multi-select list.

    A text box above the list filters the displayed tokens by literal
    substring. The selection is read and written through the ``get_*`` /
    ``set_*`` methods, and the object behaves as a sequence of the currently
    displayed tokens.
    """
    def __init__(self,
                 tokens: pd.DataFrame,
                 token_column='l2_norm_token',
                 norms_columns=None):
        self._text_widget: Text = None
        self._tokens_widget: SelectMultiple = None
        self._token_to_index = None
        super().__init__(tokens, token_column, norms_columns)

    def display(self, tokens: pd.DataFrame) -> "TokensSelector":
        super().display(tokens)
        if self.tokens is not None:
            self._rehash_token_to_index()
        return self

    def _rehash_token_to_index(self):
        """Map each token to its position in ``self.tokens``' index."""
        self._token_to_index = {
            w: i
            for i, w in enumerate(self.tokens.index.tolist())
        }

    def _create_widget(self) -> SelectMultiple:
        """Build the filter box and the token list; return their VBox."""
        _tokens = list(self.tokens.index)
        _layout = Layout(width="200px")
        self._tokens_widget = SelectMultiple(options=_tokens,
                                             value=[],
                                             rows=30)
        self._tokens_widget.layout = _layout
        self._tokens_widget.observe(self._on_selection_changed, "value")

        self._text_widget = Text(description="")
        self._text_widget.layout = _layout
        self._text_widget.observe(self._on_filter_changed, "value")

        _widget = VBox(
            [HTML("<b>Filter</b>"), self._text_widget, self._tokens_widget])

        return _widget

    def _on_filter_changed(self, *_):
        """Show only tokens containing the filter text, keeping the selection.

        The filter is matched literally (``regex=False``) so user input such
        as ``(`` or ``*`` cannot raise a regular-expression error.
        """
        _filter = self._text_widget.value.strip()
        if _filter == "":
            _options = self.tokens.index.tolist()
        else:
            mask = self.tokens.index.str.contains(_filter,
                                                  regex=False,
                                                  na=False)
            _options = self.tokens.index[mask].tolist()

        visible = set(_options)
        keep = [x for x in self._tokens_widget.value if x in visible]
        # NOTE(review): ipywidgets resets a SelectMultiple's selection when
        # its options change, so re-apply the still-visible tokens after.
        self._tokens_widget.options = _options
        self._tokens_widget.value = keep

    def get_selected_tokens(self) -> List[str]:
        """Return the selected tokens."""
        return list(self._tokens_widget.value)

    def get_selected_indices(self) -> List[int]:
        """Return the positions (in ``self.tokens``) of the selected tokens."""
        return [self._token_to_index[w] for w in self._tokens_widget.value]

    def get_tokens_slice(self, start: int = 0, n: int = 0):
        """Return ``n`` displayed tokens starting at position ``start``."""
        return self._tokens_widget.options[start:start + n]

    def set_selected_indices(self, indices: List[int]):
        """Select the displayed tokens at the given positions."""
        _options = self._tokens_widget.options
        self._tokens_widget.value = [_options[i] for i in indices]

    def __len__(self) -> int:
        return len(self._tokens_widget.options)

    def __getitem__(self, key) -> str:
        return self._tokens_widget.options[key]
Пример #7
0
def get():
    """Get the parcel's dataset for the given location or ids"""
    debug = False
    info = Label("1. Select the aoi to get parcel data.")

    values = config.read()
    ppoly_out = Output()
    progress = Output()

    def outlog(*text):
        with progress:
            print(*text)

    def outlog_poly(*text):
        with ppoly_out:
            print(*text)

    def aois_options():
        values = config.read()
        options = {}
        if values['set']['data_source'] == 'api':
            api_values = config.read('api_options.json')
            for aoi in api_values['aois']:
                options[(aoi.upper(), aoi)] = api_values['aois'][aoi]['years']
        elif values['set']['data_source'] == 'direct':
            values = config.read('api_options.json')
            for aoi in values['dataset']:
                options[(f"{aoi.upper()} ({aoi})", aoi)] = [aoi.split('_')[-1]]
        return options

    def aois_years():
        values = config.read()
        years = {}
        if values['set']['data_source'] == 'api':
            api_values = config.read('api_options.json')
            for aoi in api_values['aois']:
                years[aoi] = api_values['aois'][aoi]['years']
        elif values['set']['data_source'] == 'direct':
            values = config.read()
            for aoi in values['dataset']:
                years[aoi] = [aoi.split('_')[-1]]
        return years

    try:
        aois = Dropdown(
            options=tuple(aois_options()),
            value=values['set']['dataset'],
            description='AOI:',
        )
    except Exception:
        aois = Dropdown(
            options=tuple(aois_options()),
            description='AOI:',
        )

    def years_disabled():
        values = config.read()
        if values['set']['data_source'] == 'direct':
            return True
        else:
            return False

    year = Dropdown(
        options=next(iter(aois_options().values())),
        description='Year:',
        disabled=years_disabled(),
    )
    button_refresh = Button(layout=Layout(width='35px'), icon='fa-refresh')

    @button_refresh.on_click
    def button_refresh_on_click(b):
        values = config.read()
        if values['set']['data_source'] == 'api':
            from cbm.datas import api
            available_options = json.loads(api.get_options())
            try:
                api_options = normpath(
                    join(config.path_conf, 'api_options.json'))
                os.makedirs(dirname(api_options), exist_ok=True)
                with open(api_options, "w") as f:
                    json.dump(available_options, f, indent=4)
                outlog(f"File saved at: {api_options}")
            except Exception as err:
                outlog(f"Could not create the file 'api_options.json': {err}")

            outlog(f"The API options are updated.")
        aois.options = tuple(aois_options())
        year.options = aois_years()[aois.value]
        year.disabled = years_disabled()

    def table_options_change(change):
        api_values = config.read('api_options.json')
        id_examples = api_values['aois'][change.new]['id_examples']
        try:
            id_examples_label.value = ', '.join(str(x) for x in id_examples)
            year.options = aois_years()[change.new]
            year.disabled = years_disabled()
            pid.value = str(id_examples[0])
        except Exception:
            id_examples_label.value = ', '.join(str(x) for x in id_examples)
            aois.options = tuple(aois_options())
            year.options = aois_years()[aois.value]
            year.disabled = years_disabled()
            pid.value = str(id_examples[0])

    aois.observe(table_options_change, 'value')

    info_method = Label("2. Select a method to download parcel data.")

    method = ToggleButtons(
        options=[('Parcel ID', 2), ('Coordinates', 1), ('Map marker', 3),
                 ('Polygon', 4)],
        value=None,
        description='',
        disabled=False,
        button_style='info',
        tooltips=[
            'Enter lon lat', 'Enter parcel ID', 'Select a point on a map',
            'Get parcels id in a polygon'
        ],
    )

    plon = Text(value='5.664', placeholder='Add lon', description='Lon:')
    plat = Text(value='52.694', placeholder='Add lat', description='Lat:')
    wbox_lat_lot = VBox(children=[plat, plon])

    api_values = config.read('api_options.json')
    id_examples = api_values['aois'][aois.value]['id_examples']

    id_examples_label = Label(', '.join(str(x) for x in id_examples))
    info_pid = HBox(
        [Label("Multiple parcel ids can be added, e.g.: "), id_examples_label])

    pid = Textarea(
        value=str(id_examples[0]),
        placeholder='12345, 67890',
        description='Parcel(s) ID:',
    )

    wbox_pids = VBox(children=[info_pid, pid])

    bt_get_ids = Button(description="Find parcels",
                        disabled=False,
                        button_style='info',
                        tooltip='Find parcels within the polygon.',
                        icon='')

    get_ids_box = HBox(
        [bt_get_ids,
         Label("Find the parcels that are in the polygon.")])

    @bt_get_ids.on_click
    def bt_get_ids_on_click(b):
        with ppoly_out:
            try:
                # get_requests = data_source()
                ppoly_out.clear_output()
                polygon = get_maps.polygon_map.feature_collection['features'][
                    -1]['geometry']['coordinates'][0]
                polygon_str = '-'.join(
                    ['_'.join(map(str, c)) for c in polygon])
                outlog_poly(f"Geting parcel ids within the polygon...")
                polyids = parcel_info.by_polygon(aois.value, year.value,
                                                 polygon_str, ptype.value,
                                                 False, True)
                outlog_poly(
                    f"'{len(polyids['ogc_fid'])}' parcels where found:")
                outlog_poly(polyids['ogc_fid'])
                file = normpath(
                    join(config.get_value(['paths', 'temp']),
                         'pids_from_polygon.txt'))
                with open(file, "w") as text_file:
                    text_file.write('\n'.join(map(str, polyids['ogc_fid'])))
            except Exception as err:
                outlog("No parcel ids found:", err)

    method_out = Output(layout=Layout(border='1px solid black'))

    def method_options(obj):
        with method_out:
            method_out.clear_output()
            if obj['new'] == 1:
                display(wbox_lat_lot)
            elif obj['new'] == 2:
                display(wbox_pids)
            elif obj['new'] == 3:
                display(
                    get_maps.base_map(aois.value,
                                      config.get_value(['set',
                                                        'data_source'])))
            elif obj['new'] == 4:
                display(
                    VBox([
                        get_maps.polygon(
                            aois.value,
                            config.get_value(['set', 'data_source'])),
                        get_ids_box, ppoly_out
                    ]))

    method.observe(method_options, 'value')

    info_type = Label("3. Select datasets to download.")

    ptype = Text(value=None,
                 placeholder='(Optional) Parcel Type',
                 description='pType:',
                 disabled=False)

    table_options = HBox([aois, button_refresh, ptype, year])

    # ########### Time series options #########################################
    pts_bt = ToggleButton(
        value=False,
        description='Time series',
        button_style='success',  # success
        tooltip='Get parcel information',
        icon='toggle-off',
        layout=Layout(width='50%'))

    pts_bands = data_options.pts_bands()

    pts_tstype = SelectMultiple(
        options=[("Sentinel-2 Level 2A", 's2'),
                 ("S1 Backscattering Coefficients", 'bs'),
                 ("S1 6-day Coherence (20m)", 'c6')],
        value=['s2'],
        rows=3,
        description='TS type:',
        disabled=False,
    )

    pts_band = Dropdown(
        options=list(pts_bands['s2']),
        value='',
        description='Band:',
        disabled=False,
    )

    def pts_tstype_change(change):
        if len(pts_tstype.value) <= 1:
            pts_band.disabled = False
            try:
                pts_b = change.new[0]
                pts_band.options = pts_bands[pts_b]
            except Exception:
                pass
        else:
            pts_band.value = ''
            pts_band.disabled = True

    pts_tstype.observe(pts_tstype_change, 'value')

    pts_options = VBox(children=[pts_tstype, pts_band])

    # ########### Chip images options #########################################
    pci_bt = ToggleButton(value=False,
                          description='Chip images',
                          disabled=False,
                          button_style='success',
                          tooltip='Get parcel information',
                          icon='toggle-off',
                          layout=Layout(width='50%'))

    pci_start_date = DatePicker(
        value=datetime.date(2020, 6, 1),
        description='Start Date',
    )

    pci_end_date = DatePicker(
        value=datetime.date(2020, 6, 30),
        description='End Date',
    )

    pci_plevel = RadioButtons(
        options=['LEVEL2A', 'LEVEL1C'],
        value='LEVEL2A',
        description='Proces. level:',  # Processing level
        disabled=False,
        layout=Layout(width='50%'))

    pci_chipsize = IntSlider(value=640,
                             min=100,
                             max=5120,
                             step=10,
                             description='Chip size:',
                             disabled=False,
                             continuous_update=False,
                             orientation='horizontal',
                             readout=True,
                             readout_format='d')

    pci_bands = data_options.pci_bands()

    pci_satellite = RadioButtons(options=list(pci_bands),
                                 value='Sentinel 2',
                                 disabled=True,
                                 layout=Layout(width='100px'))

    pci_band = SelectMultiple(options=list(pci_bands['Sentinel 2']),
                              value=['B04'],
                              rows=11,
                              description='Band:',
                              disabled=False)

    sats_plevel = HBox([pci_satellite, pci_plevel])

    def on_sat_change(change):
        sat = change.new
        pci_band.options = pci_bands[sat]

    pci_satellite.observe(on_sat_change, 'value')

    pci_options = VBox(children=[
        pci_start_date, pci_end_date, sats_plevel, pci_chipsize, pci_band
    ])

    # ########### General options #############################################
    pts_wbox = VBox(children=[])
    pci_wbox = VBox(children=[])

    def pts_observe(button):
        if button['new']:
            pts_bt.icon = 'toggle-on'
            pts_wbox.children = [pts_options]
        else:
            pts_bt.icon = 'toggle-off'
            pts_wbox.children = []

    def pci_observe(button):
        if button['new']:
            pci_bt.icon = 'toggle-on'
            pci_wbox.children = [pci_options]
        else:
            pci_bt.icon = 'toggle-off'
            pci_wbox.children = []

    pts_bt.observe(pts_observe, names='value')
    pci_bt.observe(pci_observe, names='value')

    pts = VBox(children=[pts_bt, pts_wbox], layout=Layout(width='40%'))
    pci = VBox(children=[pci_bt, pci_wbox], layout=Layout(width='40%'))

    data_types = HBox(children=[pts, pci])

    info_get = Label("4. Download the selected data.")

    bt_get = Button(description='Download',
                    button_style='warning',
                    tooltip='Send the request',
                    icon='download')

    path_temp = config.get_value(['paths', 'temp'])
    path_data = config.get_value(['paths', 'data'])

    info_paths = HTML("".join([
        "<style>div.c {line-height: 1.1;}</style>",
        "<div class='c';>By default data will be stored in the temp folder ",
        f"({path_temp}), you will be asked to empty the temp folder each time ",
        "you start the notebook.<br>In your personal data folder ",
        f"({path_data}) you can permanently store the data.</div>"
    ]))

    paths = RadioButtons(options=[
        (f"Temporary folder: '{path_temp}'.", path_temp),
        (f"Personal data folder: '{path_data}'.", path_data)
    ],
                         layout={'width': 'max-content'},
                         value=path_temp)

    paths_box = Box([Label(value="Select folder:"), paths])

    def file_len(fname):
        """Return the number of lines in the text file *fname*.

        An empty file yields 0 (the previous enumerate-based version raised
        UnboundLocalError in that case because the loop variable was never
        bound).
        """
        with open(fname) as f:
            return sum(1 for _ in f)

    def get_data(parcel):
        get_requests = data_source()
        pid = str(parcel['pid'][0])
        source = config.get_value(['set', 'data_source'])
        if source == 'api':
            datapath = normpath(join(paths.value, aois.value, year.value, pid))
        elif source == 'direct':
            dataset = config.get_value(['set', 'dataset'])
            datapath = normpath(join(paths.value, dataset, pid))
        file_pinf = normpath(join(datapath, 'info.json'))
        os.makedirs(dirname(file_pinf), exist_ok=True)
        with open(file_pinf, "w") as f:
            json.dump(parcel, f)
        outlog(f"File saved at: {file_pinf}")

        if pts_bt.value is True:
            outlog(f"Getting time series for parcel: '{pid}',",
                   f"({pts_tstype.value} {pts_band.value}).")
            for pts in pts_tstype.value:
                ts = time_series.by_pid(aois.value, year.value, pid, pts,
                                        ptype.value, pts_band.value)
                band = ''
                if pts_band.value != '':
                    band = f"_{pts_band.value}"
                file_ts = normpath(
                    join(datapath, f'time_series_{pts}{band}.csv'))
                if isinstance(ts, pd.DataFrame):
                    ts.to_csv(file_ts, index=True, header=True)
                elif isinstance(ts, dict):
                    os.makedirs(os.path.dirname(file_ts), exist_ok=True)
                    df = pd.DataFrame.from_dict(ts, orient='columns')
                    df.to_csv(file_ts, index=True, header=True)
            outlog("TS Files are saved.")
        if pci_bt.value is True:
            files_pci = normpath(join(datapath, 'chip_images'))
            outlog(f"Getting '{pci_band.value}' chip images for parcel: {pid}")
            with progress:
                get_requests.rcbl(parcel, pci_start_date.value,
                                  pci_end_date.value, pci_band.value,
                                  pci_chipsize.value, files_pci)
            filet = normpath(
                join(datapath, 'chip_images',
                     f'images_list.{pci_band.value[0]}.csv'))
            if file_len(filet) > 1:
                outlog(
                    f"Completed, all GeoTIFFs for bands '{pci_band.value}' are ",
                    f"downloaded in the folder: '{datapath}/chip_images'")
            else:
                outlog(
                    "No files where downloaded, please check your configurations"
                )

    def get_from_location(lon, lat):
        outlog(f"Finding parcel information for coordinates: {lon}, {lat}")
        parcel = parcel_info.by_location(aois.value, year.value, lon, lat,
                                         ptype.value, True, False, debug)
        pid = str(parcel['pid'][0])
        outlog(f"The parcel '{pid}' was found at this location.")
        try:
            get_data(parcel)
        except Exception as err:
            print(err)

    def get_from_id(pids):
        outlog(f"Getting parcels information for: '{pids}'")
        for pid in pids:
            try:
                parcel = parcel_info.by_pid(aois.value, year.value, pid,
                                            ptype.value, True, False, debug)
                get_data(parcel)
            except Exception as err:
                print(err)

    @bt_get.on_click
    def bt_get_on_click(b):
        progress.clear_output()
        if method.value == 1:
            try:
                with progress:
                    lon, lat = plon.value, plat.value
                    get_from_location(lon, lat)
            except Exception as err:
                outlog("Could not get parcel information for location",
                       f"'{lon}', '{lat}': {err}")

        elif method.value == 2:
            try:
                with progress:
                    pids = pid.value.replace(" ", "").split(",")
                    get_from_id(pids)
            except Exception as err:
                outlog(f"Could not get parcel information: {err}")

        elif method.value == 3:
            try:
                marker = get_maps.base_map.map_marker
                lon = str(round(marker.location[1], 2))
                lat = str(round(marker.location[0], 2))
                get_from_location(lon, lat)
            except Exception as err:
                outlog(f"Could not get parcel information: {err}")
        elif method.value == 4:
            try:
                plimit = int(values['set']['plimit'])
                file = normpath(
                    join(config.get_value(['paths', 'temp']),
                         'pids_from_polygon.txt'))
                with open(file, "r") as text_file:
                    pids = text_file.read().split('\n')
                outlog("Geting data form the parcels:")
                outlog(pids)
                if len(pids) <= plimit:
                    get_from_id(pids)
                else:
                    outlog(
                        "You exceeded the maximum amount of selected parcels ",
                        f"({plimit}) to get data. Please select smaller area.")
            except Exception as err:
                outlog("No pids file found.", err)
        else:
            outlog(f"Please select method to get parcel information.")

    return VBox([
        info, table_options, info_method, method, method_out, info_type,
        data_types, info_get, info_paths, paths_box, bt_get, progress
    ])
Пример #8
0
class TrafficWidget(object):
    """Interactive notebook widget to explore a ``Traffic`` collection.

    Combines a map figure and a time-plot figure, rendered through the ipympl
    backend, with ipywidgets controls (projection, area, time window, altitude
    band, callsign search, plotted columns) that filter and redraw flights.
    """

    # -- Constructor --
    def __init__(self, traffic: Traffic, projection=EuroPP()) -> None:
        """Build the figures, the control widgets and the overall layout.

        :param traffic: the flights to explore.
        :param projection: initial map projection, either a projection
            instance or its name as a string (e.g. ``"EuroPP"``).
        """
        # Switch the notebook to the ipympl backend, then drive its canvas
        # and manager classes directly.
        ipython = get_ipython()
        ipython.magic("matplotlib ipympl")
        from ipympl.backend_nbagg import FigureCanvasNbAgg, FigureManagerNbAgg

        self.fig_map = Figure(figsize=(6, 6))
        self.fig_time = Figure(figsize=(6, 4))

        self.canvas_map = FigureCanvasNbAgg(self.fig_map)
        self.canvas_time = FigureCanvasNbAgg(self.fig_time)

        self.manager_map = FigureManagerNbAgg(self.canvas_map, 0)
        self.manager_time = FigureManagerNbAgg(self.canvas_time, 0)

        # All callbacks run inside `with self.output:`; debug() adds this
        # widget as an extra tab so its content can be inspected.
        layout = {"width": "590px", "height": "800px", "border": "none"}
        self.output = Output(layout=layout)

        self._traffic = traffic
        self.t_view = traffic.sort_values("timestamp")
        # callsign -> artists currently drawn on the map for that callsign
        self.trajectories: Dict[str, List[Artist]] = defaultdict(list)

        self.create_map(projection)

        # NOTE: this rebinds self.projection (which create_map() just set to
        # the map CRS) to the dropdown widget.  The CRS actually in use is
        # tracked separately in self._crs.
        self.projection = Dropdown(options=["EuroPP", "Lambert93", "Mercator"])
        self.projection.observe(self.on_projection_change)

        self.identifier_input = Text(description="Callsign/ID")
        self.identifier_input.observe(self.on_id_input)

        self.identifier_select = SelectMultiple(
            options=sorted(self._traffic.callsigns),  # type: ignore
            value=[],
            rows=20,
        )
        self.identifier_select.observe(self.on_id_change)

        self.area_input = Text(description="Area")
        self.area_input.observe(self.on_area_input)

        self.extent_button = Button(description="Extent")
        self.extent_button.on_click(self.on_extent_button)

        self.plot_button = Button(description="Plot")
        self.plot_button.on_click(self.on_plot_button)

        self.clear_button = Button(description="Reset")
        self.clear_button.on_click(self.on_clear_button)

        self.plot_airport = Button(description="Airport")
        self.plot_airport.on_click(self.on_plot_airport)

        self.area_select = SelectMultiple(options=[],
                                          value=[],
                                          rows=3,
                                          disabled=False)
        self.area_select.observe(self.on_area_click)

        self.altitude_select = SelectionRangeSlider(
            options=[0, 5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000],
            index=(0, 8),
            description="Altitude",
            disabled=False,
            continuous_update=False,
        )
        self.altitude_select.observe(self.on_altitude_select)

        # 100 slider positions; the labels are filled in by set_time_range()
        self.time_slider = SelectionRangeSlider(
            options=list(range(100)),
            index=(0, 99),
            description="Date",
            continuous_update=False,
        )
        self.lock_time_change = False
        self.set_time_range()

        self.time_slider.observe(self.on_time_select)
        self.canvas_map.observe(self.on_axmap_change,
                                ["_button", "_png_is_old"])
        self.canvas_time.observe(self.on_axtime_change, ["_png_is_old"])

        # Numeric columns (other than the coordinates) can be time-plotted.
        col_options = []
        for column, dtype in self._traffic.data.dtypes.items():
            if column not in ("latitude", "longitude"):
                if dtype in ["float64", "int64"]:
                    col_options.append(column)

        self.y_selector = SelectMultiple(options=col_options,
                                         value=[],
                                         rows=5,
                                         disabled=False)
        self.y_selector.observe(self.on_id_change)

        self.sec_y_selector = SelectMultiple(options=col_options,
                                             value=[],
                                             rows=5,
                                             disabled=False)
        self.sec_y_selector.observe(self.on_id_change)

        self.time_tab = VBox(
            [HBox([self.y_selector, self.sec_y_selector]), self.canvas_time])

        self.tabs = Tab()
        self.tabs.children = [self.canvas_map, self.time_tab]
        self.tabs.set_title(0, "Map")
        self.tabs.set_title(1, "Plots")

        self._main_elt = HBox([
            self.tabs,
            VBox([
                self.projection,
                HBox([self.extent_button, self.plot_button]),
                HBox([self.plot_airport, self.clear_button]),
                self.area_input,
                self.area_select,
                self.time_slider,
                self.altitude_select,
                self.identifier_input,
                self.identifier_select,
            ]),
        ])

    @property
    def traffic(self) -> Traffic:
        """The full, unfiltered traffic passed to the constructor."""
        return self._traffic

    def _ipython_display_(self) -> None:
        clear_output()
        self.canvas_map.draw_idle()
        self._main_elt._ipython_display_()

    def debug(self) -> None:
        """Add the callbacks' output widget as an extra tab (once)."""
        if self.tabs.children[-1] != self.output:
            self.tabs.children = list(self.tabs.children) + [self.output]

    def set_time_range(self) -> None:
        """Split the traffic time span into 100 steps and label the slider.

        Labels are UTC "HH:MM"; naive timestamps are first localized to the
        local timezone of the machine running the kernel.
        """
        with self.output:
            tz_now = datetime.now().astimezone().tzinfo
            self.dates = [
                self._traffic.start_time + i *
                (self._traffic.end_time - self._traffic.start_time) / 99
                for i in range(100)
            ]
            if self._traffic.start_time.tzinfo is not None:
                options = [
                    t.tz_convert("utc").strftime("%H:%M") for t in self.dates
                ]
            else:
                options = [
                    t.tz_localize(tz_now).tz_convert("utc").strftime("%H:%M")
                    for t in self.dates
                ]

            # Prevent on_time_select from filtering while the slider is
            # being reconfigured.
            self.lock_time_change = True
            self.time_slider.options = options
            self.time_slider.index = (0, 99)
            self.lock_time_change = False

    def create_map(
        self,
        projection: Union[str, Projection] = "EuroPP()"  # type: ignore
    ) -> None:
        """(Re)create the map axes with *projection* and redraw the flights.

        :param projection: a projection instance, or its name with or without
            trailing ``"()"`` (evaluated in this module's namespace).
        """
        with self.output:
            if isinstance(projection, str):
                if not projection.endswith("()"):
                    projection = projection + "()"
                # NOTE(review): eval() is acceptable only because the names
                # come from the fixed dropdown options or from the caller;
                # never feed untrusted text here.
                projection = eval(projection)

            self.projection = projection
            # Kept separately because __init__ rebinds self.projection to
            # the dropdown widget; on_clear_button rebuilds from this.
            self._crs = projection

            with plt.style.context("traffic"):

                self.fig_map.clear()
                self.trajectories.clear()
                self.ax_map = self.fig_map.add_subplot(
                    111, projection=self.projection)
                self.ax_map.add_feature(countries())
                if type(projection).__name__ == "Lambert93":
                    self.ax_map.add_feature(rivers())

                self.fig_map.set_tight_layout(True)
                self.ax_map.background_patch.set_visible(False)
                self.ax_map.outline_patch.set_visible(False)
                self.ax_map.format_coord = lambda x, y: ""
                self.ax_map.set_global()

            self.default_plot()
            self.canvas_map.draw_idle()

    def default_plot(self) -> None:
        """Draw the current position of every flight inside the map extent.

        Callsign labels are only shown when fewer than 10 flights are drawn.
        """
        with self.output:
            # clear all trajectory pieces
            for artists in self.trajectories.values():
                for elt in artists:
                    elt.remove()
            self.trajectories.clear()
            self.ax_map.set_prop_cycle(None)

            lon_min, lon_max, lat_min, lat_max = self.ax_map.get_extent(
                PlateCarree())

            cur_flights = []
            for flight in self.t_view:
                if flight is None:
                    continue
                at = flight.at()  # computed once per flight
                if (lat_min <= getattr(at, "latitude", -90) <= lat_max and
                        lon_min <= getattr(at, "longitude", -180) <= lon_max):
                    cur_flights.append(at)

            def params(at):
                if len(cur_flights) < 10:
                    return dict(s=8, text_kw=dict(s=at.callsign))
                else:
                    return dict(s=8, text_kw=dict(s=""))

            for at in cur_flights:
                if at is not None:
                    self.trajectories[at.callsign] += at.plot(
                        self.ax_map, **params(at))

            self.canvas_map.draw_idle()

    def create_timeplot(self) -> None:
        """Reset the time-plot figure to a single empty axes."""
        with plt.style.context("traffic"):
            self.fig_time.clear()
            self.ax_time = self.fig_time.add_subplot(111)
            self.fig_time.set_tight_layout(True)

    # -- Callbacks --

    def on_projection_change(self, change: Dict[str, Any]) -> None:
        with self.output:
            if change["name"] == "value":
                self.create_map(change["new"])

    def on_clear_button(self, elt: Dict[str, Any]) -> None:
        """Drop all filters and redraw the map and an empty time plot."""
        with self.output:
            self.t_view = self.traffic.sort_values("timestamp")
            # self.projection may be the dropdown widget (see __init__), which
            # add_subplot() cannot use; rebuild from the recorded CRS instead.
            self.create_map(self._crs)
            self.create_timeplot()

    def on_area_input(self, elt: Dict[str, Any]) -> None:
        """Offer the airspaces matching the typed text in area_select."""
        with self.output:
            if elt["name"] != "value":
                return
            search_text = elt["new"]
            if len(search_text) == 0:
                self.area_select.options = list()
            else:
                from ..data import airac

                self.area_select.options = list(
                    x.name for x in airac.parse(search_text))

    def on_area_click(self, elt: Dict[str, Any]) -> None:
        """Zoom the map onto the first selected airspace."""
        with self.output:
            # an emptied selection has nothing to zoom onto
            if elt["name"] != "value" or not elt["new"]:
                return
            from ..data import airac

            self.ax_map.set_extent(airac[elt["new"][0]])
            self.canvas_map.draw_idle()

    def on_extent_button(self, elt: Dict[str, Any]) -> None:
        """Set the map extent from the selected area, then re-filter."""
        with self.output:
            if len(self.area_select.value) == 0:
                if len(self.area_input.value) == 0:
                    self.ax_map.set_global()
                else:
                    self.ax_map.set_extent(location(self.area_input.value))
            else:
                from ..data import airac

                self.ax_map.set_extent(airac[self.area_select.value[0]])

            t1, t2 = self.time_slider.index
            low, up = self.altitude_select.value
            self.on_filter(low, up, t1, t2)
            self.canvas_map.draw_idle()

    def on_axtime_change(self, change: Dict[str, Any]) -> None:
        with self.output:
            if change["name"] == "_png_is_old":
                # go away!! (hide the window title of the time canvas)
                return self.canvas_time.set_window_title("")

    def on_axmap_change(self, change: Dict[str, Any]) -> None:
        """Re-filter when a mouse interaction on the map ends (pan/zoom)."""
        with self.output:
            if change["name"] == "_png_is_old":
                # go away!!
                return self.canvas_map.set_window_title("")
            if change["new"] is None:
                t1, t2 = self.time_slider.index
                low, up = self.altitude_select.value
                self.on_filter(low, up, t1, t2)

    def on_id_input(self, elt: Dict[str, Any]) -> None:
        """Restrict the callsign list to those matching the typed regex."""
        with self.output:
            if elt["name"] != "value":
                return
            # typing issue because of the lru_cache_wrappen
            callsigns = cast(Set[str], self.t_view.callsigns)
            try:
                pattern = re.compile(elt["new"], flags=re.IGNORECASE)
            except re.error:
                # pattern still incomplete while typing: keep current list
                return
            self.identifier_select.options = sorted(
                callsign for callsign in callsigns if pattern.match(callsign))

    def on_plot_button(self, elt: Dict[str, Any]) -> None:
        """Outline the selected area/airspace, or redo the default plot."""
        with self.output:
            if len(self.area_select.value) == 0:
                if len(self.area_input.value) == 0:
                    return self.default_plot()
                location(self.area_input.value).plot(self.ax_map,
                                                     color="grey",
                                                     linestyle="dashed")
            else:
                from ..data import airac

                airspace = airac[self.area_select.value[0]]
                if airspace is not None:
                    airspace.plot(self.ax_map)
            self.canvas_map.draw_idle()

    def on_plot_airport(self, elt: Dict[str, Any]) -> None:
        """Plot the named airport, or airports in a small enough extent."""
        with self.output:
            if len(self.area_input.value) == 0:
                from cartotools.osm import request, tags

                west, east, south, north = self.ax_map.get_extent(
                    crs=PlateCarree())
                if abs(east - west) > 1 or abs(north - south) > 1:
                    # that would be a too big request
                    return
                request((west, south, east, north),
                        **tags.airport).plot(self.ax_map)
            else:
                from ..data import airports

                airports[self.area_input.value].plot(self.ax_map)
            self.canvas_map.draw_idle()

    def on_id_change(self, change: Dict[str, Any]) -> None:
        """Plot the selected flights on the map and their columns over time.

        With several y columns selected only the first callsign is plotted;
        ``altitude`` is plotted when no column is selected.
        """
        with self.output:
            if change["name"] != "value":
                return

            y = self.y_selector.value + self.sec_y_selector.value
            secondary_y = self.sec_y_selector.value
            callsigns = self.identifier_select.value

            if len(y) == 0:
                y = ["altitude"]
            extra_dict = dict()
            if len(y) > 1:
                # just to avoid confusion...
                callsigns = callsigns[:1]

            # clear all trajectory pieces
            self.create_timeplot()
            for artists in self.trajectories.values():
                for elt in artists:
                    elt.remove()
            self.trajectories.clear()

            for callsign in callsigns:
                flight = self.t_view[callsign]
                if len(y) == 1:
                    extra_dict["label"] = callsign
                if flight is not None:
                    try:
                        self.trajectories[callsign] += flight.plot(self.ax_map)
                        at = flight.at()
                        if at is not None:
                            self.trajectories[callsign] += at.plot(
                                self.ax_map, s=8, text_kw=dict(s=callsign))
                    except Exception:  # NoneType object is not iterable
                        pass

                    try:
                        flight.plot_time(
                            self.ax_time,
                            y=y,
                            secondary_y=secondary_y,
                            **extra_dict,
                        )
                    except Exception:  # no numeric data to plot
                        pass

            if len(callsigns) == 0:
                self.default_plot()
            else:
                self.ax_time.legend()

            # non conformal with traffic style
            for elt in self.ax_time.get_xticklabels():
                elt.set_size(12)
            for elt in self.ax_time.get_yticklabels():
                elt.set_size(12)
            self.ax_time.set_xlabel("")

            self.canvas_map.draw_idle()
            self.canvas_time.draw_idle()

            if len(callsigns) != 0:
                # Pad a (nearly) flat y-range by 5% of |up|; abs() keeps the
                # limits ordered for negative values, and up == 0 is skipped
                # (the former relative test divided by it).
                low, up = self.ax_time.get_ylim()
                margin = .05 * abs(up)
                if up != 0 and (up - low) < margin:
                    self.ax_time.set_ylim(up - margin, up + margin)
                    self.canvas_time.draw_idle()

    def on_filter(self, low, up, t1, t2) -> None:
        """Restrict t_view to the altitude band, time window and map extent.

        :param low: lower altitude bound.
        :param up: upper altitude bound.
        :param t1: start index into self.dates.
        :param t2: end index into self.dates.
        """
        with self.output:
            west, east, south, north = self.ax_map.get_extent(
                crs=PlateCarree())

            # `altitude != altitude` keeps rows whose altitude is NaN
            self.t_view = (self.traffic.between(
                self.dates[t1], self.dates[t2]).query(
                    f"{low} <= altitude <= {up} or altitude != altitude"
                ).query(
                    f"{west} <= longitude <= {east} and "
                    f"{south} <= latitude <= {north}").sort_values("timestamp")
                           )
            self.identifier_select.options = sorted(
                flight.callsign for flight in self.t_view
                if flight is not None and re.match(
                    self.identifier_input.value,
                    flight.callsign,
                    flags=re.IGNORECASE,
                ))
            return self.default_plot()

    def on_altitude_select(self, change: Dict[str, Any]) -> None:
        with self.output:
            if change["name"] != "value":
                return

            low, up = change["new"]
            t1, t2 = self.time_slider.index
            self.on_filter(low, up, t1, t2)

    def on_time_select(self, change: Dict[str, Any]) -> None:
        with self.output:
            if self.lock_time_change:
                return
            if change["name"] != "index":
                return
            t1, t2 = change["new"]
            low, up = self.altitude_select.value
            self.on_filter(low, up, t1, t2)
Пример #9
0
    def ui(self):
        """
        QA user interface.

        Displays, for the current quality layer, a selector of its QA flag
        names and a multi-select of the descriptions of the chosen flag, plus
        buttons to select all / default descriptions, run the QA analytics and
        save/load the selection as ``<QualityLayer>.json`` in
        ``self.source_dir``.  The selection is stored in
        ``self.user_qa_selection`` (flag name -> selected descriptions).
        """
        # Clear cell
        self.__clear_cell()

        if self.qa_def is None:
            # Use the first or unique QA
            self.qa_def = self.qa_defs[0]

        def descriptions(flag_name):
            """Return the list of descriptions defined for QA flag *flag_name*."""
            qa_def = self.qa_def
            return qa_def[qa_def.Name == flag_name].Description.tolist()

        def fill_default_selection():
            """Select only the first description of every QA flag."""
            for flag_name in self.user_qa_selection:
                self.user_qa_selection[flag_name] = tuple(
                    descriptions(flag_name)[:1])

        def show_first_flag():
            """Show the first QA flag together with its stored selection."""
            qa_flag.value = qa_flags[0]
            qa_description.value = self.user_qa_selection[qa_flags[0]]

        qa_flags = self.qa_def.Name.unique()
        qa_layer = self.qa_def.QualityLayer.unique()

        qa_layer_header = HTML(
            value=f"<b>{qa_layer[0]}</b>",
            description='QA layer:'
        )

        self.user_qa_selection = collections.OrderedDict(
            (element, '') for element in qa_flags)

        # Fill default selection
        fill_default_selection()

        qa_flag = Select(
            options=qa_flags,
            value=qa_flags[0],
            rows=len(qa_flags),
            description='QA Parameter name:',
            style={'description_width': 'initial'},
            layout={'width': '400px'},
            disabled=False
        )

        def on_qa_flag_change(change):
            if change['type'] == 'change' and change['name'] == 'value':
                qa_flag_value = change.owner.value

                # Get user selection before changing qa description:
                # presumably assigning new options can reset the value and
                # fire on_qa_description_change, overwriting the stored entry.
                tmp_selection = self.user_qa_selection[qa_flag_value]

                _options = descriptions(qa_flag_value)
                qa_description.options = _options
                qa_description.rows = len(_options)
                qa_description.value = tmp_selection

        qa_flag.observe(on_qa_flag_change)

        initial_options = descriptions(qa_flag.value)
        qa_description = SelectMultiple(
            options=tuple(initial_options),
            value=tuple(initial_options[:1]),
            rows=len(initial_options),
            description='Description',
            disabled=False,
            style={'description_width': 'initial'},
            layout={'width': '400px'}
        )

        def on_qa_description_change(change):
            # Remember the user's choice for the currently shown flag
            if change['type'] == 'change' and change['name'] == 'value':
                self.user_qa_selection[qa_flag.value] = qa_description.value

        qa_description.observe(on_qa_description_change)

        def select_all_qa(b):
            for flag_name in self.user_qa_selection:
                self.user_qa_selection[flag_name] = tuple(
                    descriptions(flag_name))
            show_first_flag()

        # Select all button
        select_all = Button(
            description='Select ALL',
            layout={'width': '19%'}
        )
        select_all.on_click(select_all_qa)

        # Default selection
        select_default = Button(
            description='Default selection',
            layout={'width': '19%'}
        )

        def select_default_qa(b):
            fill_default_selection()
            show_first_flag()

        select_default.on_click(select_default_qa)

        left_box = VBox([qa_flag])
        right_box = VBox([qa_description])
        _HBox_qa = HBox([left_box, right_box],
                        layout={'height': '300px',
                                'width': '99%'}
        )

        analytics = Button(
            description='QA analytics',
            layout={'width': '19%'}
        )
        analytics.on_click(self._analytics)

        analytics_settings_save = Button(
            description='Save QA analytics',
            layout={'width': '19%'}
        )
        analytics_settings_save.on_click(self.__analytics_settings_save)

        # Load user-defined settings
        analytics_settings_load = Button(
            description='Load QA analytics',
            layout={'width': '19%'}
        )

        def load_analytics_settings(b):
            # Load user-defined QA saved settings from a JSON file
            fname = f"{self.qa_def.QualityLayer.unique()[0]}.json"
            fname = os.path.join(self.source_dir, fname)

            # Previously this branch was a no-op `pass` and the open() below
            # raised FileNotFoundError; bail out instead.
            if not os.path.exists(fname):
                print(f"No saved QA analytics settings found: {fname}")
                return

            with open(fname, 'r') as f:
                self.user_qa_selection = collections.OrderedDict(json.load(f))

            show_first_flag()

        analytics_settings_load.on_click(load_analytics_settings)

        # Display QA HBox
        display(qa_layer_header, _HBox_qa)

        _HBox_buttons = HBox([select_all, select_default, analytics,
                              analytics_settings_save,
                              analytics_settings_load])

        display(_HBox_buttons)
Пример #10
0
def generate_model_grid(df_X, number_of_models, models,
                        on_click_feature_exclude_button,
                        on_value_change_split_type_dropdown,
                        on_click_model_train_button):
    """Lay out the per-model training controls in a ``GridBox``.

    Widgets are appended row by row, one cell per model, and every row is
    finished with ``add_dummy_widgets`` (presumably padding the row up to
    ``min_number`` cells -- confirm in that helper).

    Rows, top to bottom:
        1. "Remove features for model N" labels
        2. feature ``SelectMultiple``   -> ``model.remove_features_sm``
        3. "Remove features" buttons    -> ``model.remove_features_button``
        4. "Train model N" labels
        5. algorithm ``Dropdown``       -> ``model.model_type_dd``
        6. split-type ``Dropdown``      -> ``model.split_type_dd``
        7. "Cross columns for model N" labels
        8. disabled cross-columns ``SelectMultiple`` -> ``model.cross_columns_sm``
        9. "Train model" buttons        -> ``model.train_model_button``

    Args:
        df_X: frame whose column names populate the feature selectors.
        number_of_models: number of model cells created per row.
        models: collection passed to ``get_model_by_id`` to fetch model ``i``.
        on_click_feature_exclude_button: click handler for the row-3 buttons.
        on_value_change_split_type_dropdown: ``value`` observer for row 6.
        on_click_model_train_button: click handler for the row-9 buttons.

    Returns:
        The assembled ``GridBox``.
    """
    feature_names = list(df_X.columns)
    feature_count = len(feature_names)
    min_number = 3
    children = []
    model_ids = range(number_of_models)

    def fill_layout():
        # Every widget gets its own fresh Layout instance.
        return Layout(width='auto', height='auto')

    def finish_row():
        add_dummy_widgets(min_number, children, number_of_models)

    split_type_help = (
        'Splits the features and the target into train/test split training '
        'sets with a balanced number of examples for each of the categories of'
        ' the columns provided. For example, if the columns provided are '
        '“gender” and “loan”, the resulting splits would contain an equal '
        'number of examples for Male with Loan Approved, Male with '
        'Loan Rejected, Female with Loan Approved, and Female with '
        'Loan Rejected.')
    cross_columns_help = (
        'One or more positional arguments (passed as *args) '
        'that are used to split the data into the cross product '
        'of their values.')

    # Row 1: feature-removal headings.
    for idx in model_ids:
        children.append(Label(layout=fill_layout(),
                              value=f'Remove features for model {idx + 1}'))
    finish_row()

    # Row 2: feature selectors, showing at most 20 entries at once.
    for idx in model_ids:
        selector = SelectMultiple(options=feature_names,
                                  rows=min(feature_count, 20),
                                  layout=fill_layout())
        get_model_by_id(models, idx).remove_features_sm = selector
        children.append(selector)
    finish_row()

    # Row 3: "Remove features" buttons.
    for idx in model_ids:
        remove_btn = Button(description='Remove features',
                            disabled=False,
                            button_style='danger',
                            tooltip='Click me',
                            icon='trash',
                            layout=fill_layout())
        remove_btn.on_click(on_click_feature_exclude_button)
        get_model_by_id(models, idx).remove_features_button = remove_btn
        children.append(remove_btn)
    finish_row()

    # Row 4: training headings.
    for idx in model_ids:
        children.append(Label(layout=fill_layout(),
                              value=f'Train model {idx + 1}'))
    finish_row()

    # Row 5: algorithm choice, options supplied by each model's type.
    for idx in model_ids:
        model = get_model_by_id(models, idx)
        type_dd = Dropdown(options=model.model_type.algorithm_options,
                           description='Model type:',
                           disabled=False,
                           layout=fill_layout())
        model.model_type_dd = type_dd
        children.append(type_dd)
    finish_row()

    # Row 6: train/test split strategy.
    for idx in model_ids:
        split_dd = Dropdown(options=[s.name for s in SplitTypes],
                            description='Train/Test split type:',
                            disabled=False,
                            layout=fill_layout(),
                            description_tooltip=split_type_help)
        model = get_model_by_id(models, idx)
        split_dd.observe(on_value_change_split_type_dropdown, names='value')
        model.split_type_dd = split_dd
        children.append(split_dd)
    finish_row()

    # Row 7: cross-column headings.
    for idx in model_ids:
        children.append(Label(layout=fill_layout(),
                              value=f'Cross columns for model {idx + 1}'))
    finish_row()

    # Row 8: cross-column selectors (start disabled).
    for idx in model_ids:
        model = get_model_by_id(models, idx)
        cross_sm = SelectMultiple(options=model.X,
                                  rows=20 if feature_count > 20 else 8,
                                  layout=fill_layout(),
                                  description='',
                                  disabled=True,
                                  description_tooltip=cross_columns_help)
        model.cross_columns_sm = cross_sm
        children.append(cross_sm)
    finish_row()

    # Row 9: "Train model" buttons.
    for idx in model_ids:
        train_btn = Button(description='Train model',
                           disabled=False,
                           button_style='success',
                           tooltip='Click me',
                           icon='cogs',
                           layout=fill_layout())
        train_btn.on_click(on_click_model_train_button)
        get_model_by_id(models, idx).train_model_button = train_btn
        children.append(train_btn)
    finish_row()

    grid_layout = Layout(
        width='auto',
        grid_template_columns=get_grid_template_columns(
            number_of_models, min_number),
        align_items='center',
        grid_template_rows='auto auto auto',
        grid_gap='1px 1px')
    return GridBox(children=children, layout=grid_layout)
Пример #11
0
class Arulesviz:
    """Interactive explorer for association rules mined from transactions.

    Rules are mined with ``apriori`` (:meth:`create_rules`) and drawn as an
    interactive ``Graph`` (:meth:`plot_graph`). In the graph, small rule
    nodes connect antecedent products (links into the rule node) to
    consequent products (links out of it). Sliders and product selectors
    re-filter the rules and redraw the graph via :meth:`replot_graph`.

    Parameters
    ----------
    transactions:
        Iterable of transactions; each transaction is an iterable of
        product names (used as node labels).
    min_sup, min_conf: float
        Support / confidence thresholds used for mining and filtering.
    min_lift: float
        Initial lift filter.
    max_sup: float
        After mining, rules with support >= ``max_sup`` are discarded.
    min_slift: float or None
        Initial standardized-lift filter; ``None`` falls back to ``min_lift``.
    products_to_drop: iterable or None
        Products removed from transactions before mining and excluded from
        the graph.
    """

    def __init__(
        self,
        transactions,
        min_sup,
        min_conf,
        min_lift,
        max_sup=1.0,
        min_slift=0.1,
        products_to_drop=None,
    ):
        self.rules = []
        self.transactions = transactions
        self.min_lift = min_lift
        # Only an explicit None falls back to min_lift; 0 is a legitimate
        # "do not filter on standardized lift" threshold.
        self.min_slift = min_lift if min_slift is None else min_slift
        self.min_sup = min_sup
        self.min_conf = min_conf
        self.max_sup = max_sup
        self.products_to_in = []
        # Copy, so neither the caller's list nor a shared default is aliased.
        self.products_to_out = (
            [] if products_to_drop is None else list(products_to_drop)
        )
        self._hovered_product = None

    def _standardized_lift(self, rule, s=None, c=None):
        """Return the standardized lift of ``rule``.

        The rule's lift is rescaled as ``(lift - L) / (U - L)``, where ``L``
        and ``U`` are lower/upper bounds derived from P(A), P(B) and the
        mining thresholds.

        Parameters
        ----------
        rule:
              Target rule (needs ``support``, ``confidence``, ``lift``).
        s: float or None
           Support threshold used for rule mining (default ``self.min_sup``).
        c: float or None
           Confidence threshold used for rule mining (default
           ``self.min_conf``).
        """
        s = self.min_sup if s is None else s
        c = self.min_conf if c is None else c
        prob_A = rule.support / rule.confidence
        prob_B = rule.confidence / rule.lift
        mult_A_and_B = prob_A * prob_B
        L = max(
            1 / prob_A + 1 / prob_B - 1 / mult_A_and_B,
            s / mult_A_and_B,
            c / prob_B,
            0,
        )
        U = min(1 / prob_A, 1 / prob_B)
        # NOTE(review): U == L (e.g. an antecedent present in every
        # transaction) may raise ZeroDivisionError here.
        return (rule.lift - L) / (U - L)

    def create_rules(self, drop_products=True, max_sup=None):
        """Mine rules, attach ``slift`` to each, and drop rules with support >= max_sup.

        Also records the largest remaining support/confidence, which are used
        as slider maxima (0.0 when no rule survives).
        """
        max_sup = max_sup or self.max_sup
        transactions = self.transactions
        if drop_products:
            to_drop = set(self.products_to_out)
            transactions = [set(t) - to_drop for t in transactions]
            # Transactions emptied by the drop carry no information.
            transactions = [t for t in transactions if t]
        _, self.rules = apriori(
            transactions, min_support=self.min_sup, min_confidence=self.min_conf
        )
        for rule in self.rules:
            rule.slift = self._standardized_lift(rule)
        self.rules = self.filter_numeric(
            "support", max_sup, self.rules, should_be_lower=True
        )
        self._max_sup = max((r.support for r in self.rules), default=0.0)
        self._max_conf = max((r.confidence for r in self.rules), default=0.0)

    def filter_numeric(self, atr, val, rules, should_be_lower=False):
        """Keep rules whose ``atr`` is strictly above ``val`` (below if ``should_be_lower``)."""
        if should_be_lower:
            return [x for x in rules if getattr(x, atr) < val]
        return [x for x in rules if getattr(x, atr) > val]

    def filter_drop_if_name_in(self, vals, rules, lhs=True, rhs=True):
        """Drop rules mentioning any of ``vals`` on the checked side(s)."""
        vals = set(vals)
        return [
            r
            for r in rules
            if not ((lhs and vals & set(r.lhs)) or (rhs and vals & set(r.rhs)))
        ]

    def filter_drop_if_name_out(self, vals, rules, lhs=True, rhs=True):
        """Keep only rules mentioning at least one of ``vals`` on the checked side(s)."""
        vals = set(vals)
        return [
            r
            for r in rules
            if (lhs and vals & set(r.lhs)) or (rhs and vals & set(r.rhs))
        ]

    def get_unique_products(self, rules):
        """Return the set of every product appearing on either side of ``rules``.

        Works for any number of rules, including zero or one.
        """
        products = set()
        for rule in rules:
            products.update(rule.lhs)
            products.update(rule.rhs)
        return products

    def create_graph(self, rules):
        """Convert ``rules`` into ``(nodes, links, colors)`` for a ``Graph`` mark.

        Each rule becomes a small circle node (radius = lift clamped to
        [2, 7]) with its text as tooltip. Each product becomes one shared
        rectangle node. Rules covering an already-drawn product combination
        are skipped.
        """
        nodes = []
        links = []
        colors = []
        name_to_id = {}
        already_seen = set()

        def product_node_id(name):
            # Reuse the product's node, creating it on first sight.
            node_id = name_to_id.get(name)
            if node_id is None:
                nodes.append(
                    {
                        "label": name,
                        "shape": "rect",
                        "is_rule": False,
                        "shape_attrs": {
                            "width": 6 * len(name) + 8,
                            "height": 20,
                        },
                    }
                )
                colors.append("white")
                node_id = len(nodes) - 1
                name_to_id[name] = node_id
            return node_id

        for rule in rules:
            combination = tuple(sorted(set(rule.lhs) | set(rule.rhs)))
            if combination in already_seen:
                continue
            already_seen.add(combination)
            nodes.append(
                {
                    "label": ".",
                    "shape": "circle",
                    "shape_attrs": {"r": max(min(rule.lift, 7), 2)},
                    "is_rule": True,
                    "tooltip": str(rule),
                }
            )
            colors.append("black")
            rule_id = len(nodes) - 1

            for name in rule.lhs:
                links.append(
                    {"source": product_node_id(name), "target": rule_id, "value": rule.lift}
                )
            for name in rule.rhs:
                links.append(
                    {"source": rule_id, "target": product_node_id(name), "value": rule.lift}
                )
        return nodes, links, colors

    def replot_graph(self):
        """Apply every current threshold and product filter and redraw the graph."""
        sub_rules = self.filter_numeric("lift", self.min_lift, rules=self.rules)
        sub_rules = self.filter_numeric("support", self.min_sup, rules=sub_rules)
        sub_rules = self.filter_numeric("slift", self.min_slift, rules=sub_rules)
        sub_rules = self.filter_numeric("confidence", self.min_conf, rules=sub_rules)
        sub_rules = self.filter_drop_if_name_in(self.products_to_out, rules=sub_rules)
        sub_rules = self.filter_drop_if_name_out(self.products_to_in, rules=sub_rules)
        self.graph.node_data, self.graph.link_data, _ = self.create_graph(sub_rules)

    def handler_products_out_filter(self, value):
        """Observer: update the excluded products and redraw."""
        self.products_to_out = value["new"]
        self.replot_graph()

    def setup_products_out_selector(self):
        """Create the multi-select of products to exclude (initially none)."""
        self.selector_products_out = SelectMultiple(
            options=sorted(self.get_unique_products(self.rules)),
            value=[],
            rows=10,
            disabled=False,
        )
        self.selector_products_out.observe(self.handler_products_out_filter, "value")

    def handler_products_in_filter(self, value):
        """Observer: update the required products and redraw."""
        self.products_to_in = value["new"]
        self.replot_graph()

    def setup_products_in_selector(self):
        """Create the multi-select of products to include (initially all)."""
        products = sorted(self.get_unique_products(self.rules))
        self.selector_products_in = SelectMultiple(
            options=products,
            value=products,
            rows=10,
            disabled=False,
        )
        self.products_to_in = list(products)
        self.selector_products_in.observe(self.handler_products_in_filter, "value")

    def set_slider_value(self, value):
        """Observer: store the slider value in ``self.<description>`` and redraw."""
        setattr(self, value["owner"].description, value["new"])
        self.replot_graph()

    def _setup_threshold_slider(self, name, slider_cls, **options):
        """Create ``self.slider_<name>`` bound to the ``self.min_<name>`` threshold."""
        slider = slider_cls(
            value=getattr(self, f"min_{name}"),
            description=f"min_{name}",
            disabled=False,
            continuous_update=False,
            orientation="horizontal",
            readout=True,
            **options,
        )
        setattr(self, f"slider_{name}", slider)
        slider.observe(self.set_slider_value, "value")

    def setup_lift_slider(self):
        """Log-scale lift slider covering 10**-0.5 .. 10**1.5."""
        self._setup_threshold_slider(
            "lift",
            FloatLogSlider,
            min=-0.5,
            max=1.5,
            step=0.05,
            base=10,
            readout_format=".3f",
        )

    def setup_conf_slider(self):
        """Confidence slider from 0 to the largest mined confidence."""
        # FloatSlider has no ``base`` trait, so none is passed.
        self._setup_threshold_slider(
            "conf",
            FloatSlider,
            min=0.0,
            max=self._max_conf,
            step=0.0001,
            readout_format=".5f",
        )

    def setup_slift_slider(self):
        """Standardized-lift slider from 0 to 1."""
        self._setup_threshold_slider(
            "slift",
            FloatSlider,
            min=0.0,
            max=1.0,
            step=0.0001,
            readout_format=".5f",
        )

    def setup_sup_slider(self):
        """Support slider from 0 to the largest mined support."""
        self._setup_threshold_slider(
            "sup",
            FloatSlider,
            min=0.0,
            max=self._max_sup,
            step=0.0001,
            readout_format=".5f",
        )

    def _save_graph_img(self, b):
        """Button callback: save the figure as a timestamped PNG."""
        self.fig.save_png(
            f"arulesviz_{datetime.datetime.now().isoformat().replace(':','-').split('.')[0]}.png"
        )

    def setup_graph_to_img_button(self):
        """Create the "Save img!" button."""
        self.graph_to_img_button = Button(description="Save img!")
        self.graph_to_img_button.on_click(self._save_graph_img)

    def plot_graph(
        self,
        width=1000,
        height=750,
        charge=-200,
        link_type="arc",
        directed=True,
        link_distance=100,
    ):
        """Build the full widget: product selectors, sliders, save button and graph.

        The initial graph applies only the ``min_lift`` filter; the other
        thresholds take effect once a control triggers :meth:`replot_graph`.
        """
        nodes, links, _ = self.create_graph(
            self.filter_numeric("lift", self.min_lift, rules=self.rules)
        )
        self.graph = Graph(
            node_data=nodes,
            link_data=links,
            charge=charge,
            link_type=link_type,
            directed=directed,
            link_distance=link_distance,
        )
        self.fig = Figure(
            marks=[self.graph],
            layout=Layout(width=f"{width}px", height=f"{height}px"),
            fig_margin=dict(top=0, bottom=0, left=0, right=0),
            legend_text={"font-size": 7},
        )

        self.graph.on_element_click(self.hover_handler)
        self.graph.on_background_click(self.clean_tooltip)
        self.graph.interactions = {"click": "tooltip"}
        self.setup_sup_slider()
        self.setup_lift_slider()
        self.setup_conf_slider()
        self.setup_slift_slider()
        self.setup_products_in_selector()
        self.setup_products_out_selector()
        self.setup_graph_to_img_button()
        self.setup_product_tooltip()
        return VBox(
            [
                HBox(
                    [
                        self.selector_products_in,
                        self.selector_products_out,
                        VBox(
                            [
                                self.slider_lift,
                                self.slider_slift,
                                self.slider_conf,
                                self.slider_sup,
                            ]
                        ),
                        self.graph_to_img_button,
                    ]
                ),
                self.fig,
            ]
        )

    def clean_tooltip(self, x, y):
        """Background-click handler: hide the tooltip.

        It also forgets the last clicked element, so clicking that same
        element again shows its tooltip again.
        """
        self.graph.tooltip = None
        self._hovered_product = None

    def plot_scatter(
        self,
        products=None,
        min_width=600,
        min_height=600,
        max_width=600,
        max_height=600,
        with_toolbar=True,
        display_names=False,
    ):
        """Scatter support*100 vs confidence*100, coloured by lift.

        When ``products`` is non-empty, only rules mentioning one of them
        are plotted. Returns the figure, stacked over a toolbar when
        ``with_toolbar``.
        """
        if products:
            sub_rules = self.filter_drop_if_name_out(products, self.rules)
        else:
            sub_rules = self.rules
        data_x = [np.round(x.support * 100, 3) for x in sub_rules]
        data_y = [np.round(x.confidence * 100, 3) for x in sub_rules]
        color = [np.round(x.lift, 4) for x in sub_rules]
        names = [str(sr) for sr in sub_rules]
        sc_x = LinearScale()
        sc_y = LinearScale()
        sc_color = ColorScale(scheme="Reds")
        ax_c = ColorAxis(
            scale=sc_color,
            tick_format="",
            label="Lift",
            orientation="vertical",
            side="right",
        )
        tt = Tooltip(fields=["name"], formats=[""])
        scatt = Scatter(
            x=data_x,
            y=data_y,
            color=color,
            scales={"x": sc_x, "y": sc_y, "color": sc_color},
            tooltip=tt,
            names=names,
            display_names=display_names,
        )
        ax_x = Axis(scale=sc_x, label="Sup*100")
        ax_y = Axis(scale=sc_y, label="Conf*100", orientation="vertical")
        m_chart = dict(top=50, bottom=70, left=50, right=100)
        fig = Figure(
            marks=[scatt],
            axes=[ax_x, ax_y, ax_c],
            fig_margin=m_chart,
            layout=Layout(
                min_width=f"{min_width}px",
                min_height=f"{min_height}px",
                max_width=f"{max_width}px",
                max_height=f"{max_height}px",
            ),
        )
        if with_toolbar:
            toolbar = Toolbar(figure=fig)
            return VBox([fig, toolbar])
        return fig

    def setup_product_tooltip(self, products=None):
        """Use a scatter of the rules for ``products`` (all rules if empty) as graph tooltip."""
        products = [] if products is None else products
        self.graph.tooltip = self.plot_scatter(products)
        if len(products) == 1:
            self.graph.tooltip.title = products[-1]
        else:
            self.graph.tooltip.title = "Products scatter"

    def hover_handler(self, qq, content):
        """Element-click handler.

        A rule node shows its text in a ``Textarea``; a product node shows
        the scatter of its rules. Re-clicking the current element does
        nothing.
        """
        data = content.get("data", {})
        product = data.get("label", -1)
        rule_text = data.get("tooltip", None)
        if product == self._hovered_product:
            return
        if rule_text:
            self._hovered_product = rule_text
            self.graph.tooltip = Textarea(rule_text)
        else:
            self._hovered_product = product
            self.setup_product_tooltip([product])
        self.graph.tooltip_location = "center"
Пример #12
0
class Text(TextTrainerMixin):
    """Notebook widget for configuring and training a text service.

    The constructor arguments are forwarded to ``TextTrainerMixin`` (defined
    elsewhere). The mixin presumably turns them into ``self.<name>`` widgets,
    since they are read back below through ``.value``. On top of that, this
    class adds a small explorer: label pickers for the training and testing
    repositories, a file list, and a preview of the selected files.
    """

    def __init__(
            self,
            sname: str,
            *,
            mllib: str = "caffe",
            engine: Engine = "CUDNN_SINGLE_HANDLE",
            training_repo: Path,
            testing_repo: Optional[List[Path]] = None,
            description: str = "Text service",
            model_repo: Path = None,
            host: str = "localhost",
            port: int = 1234,
            path: str = "",
            gpuid: GPUIndex = 0,
            # -- specific
            regression: bool = False,
            db: bool = True,
            nclasses: int = -1,
            ignore_label: Optional[int] = -1,
            layers: List[str] = [],
            dropout: float = .2,
            iterations: int = 25000,
            test_interval: int = 1000,
            snapshot_interval: int = 1000,
            base_lr: float = 0.001,
            lr_policy: str = "fixed",
            stepvalue: List[int] = [],
            warmup_lr: float = 0.0001,
            warmup_iter: int = 0,
            resume: bool = False,
            solver_type: Solver = "SGD",
            sam: bool = False,
            swa: bool = False,
            lookahead: bool = False,
            lookahead_steps: int = 6,
            lookahead_alpha: float = 0.5,
            rectified: bool = False,
            decoupled_wd_periods: int = 4,
            decoupled_wd_mult: float = 2.0,
            lr_dropout: float = 1.0,
            batch_size: int = 128,
            test_batch_size: int = 32,
            shuffle: bool = True,
            tsplit: float = 0.2,
            min_count: int = 10,
            min_word_length: int = 5,
            count: bool = False,
            tfidf: bool = False,
            sentences: bool = False,
            characters: bool = False,
            sequence: int = -1,
            read_forward: bool = True,
            alphabet: str = alpha,
            sparse: bool = False,
            template: Optional[str] = None,
            activation: str = "relu",
            embedding: bool = False,
            objective: str = '',
            class_weights: List[float] = [],
            scale_pos_weight: float = 1.0,
            autoencoder: bool = False,
            lregression: bool = False,
            finetune: bool = False,
            weights: str = "",
            iter_size: int = 1,
            target_repository: str = "",
            # -- txt input connector options for BERT / GPT-2 tokenization
            ordered_words: bool = True,
            wordpiece_tokens: bool = True,
            punctuation_tokens: bool = True,
            lower_case: bool = False,
            word_start: str = "Ġ",
            suffix_start: str = "",
            # -- end of BERT / GPT-2 options
            embedding_size: int = 768,
            freeze_traced: bool = False,
            **kwargs) -> None:
        """Build the widgets and wire the label/file explorer observers.

        NOTE(review): the list defaults above are shared between calls. They
        are kept because ``locals()`` is forwarded verbatim to the mixin;
        confirm the mixin never mutates them.
        """
        # locals() hands every constructor argument to the mixin, so no
        # local may be bound before this call.
        super().__init__(sname, locals())

        self.train_labels = SelectMultiple(options=[],
                                           value=[],
                                           description="Training labels",
                                           disabled=False)

        self.test_labels = SelectMultiple(options=[],
                                          value=[],
                                          description="Testing labels",
                                          disabled=False)

        self.training_repo.observe(  # type: ignore
            self.update_label_list, names="value")

        self.train_labels.observe(self.update_train_file_list, names="value")
        self.test_labels.observe(self.update_test_file_list, names="value")
        self.file_list.observe(self.display_text, names="value")

        self.update_label_list(())

        self._img_explorer.children = [
            HBox([HBox([self.train_labels, self.test_labels])]),
            self.file_list,
            self.output,
        ]

        # ``self.characters`` is a widget (read via ``.value`` elsewhere), so
        # test its value: a widget object itself is always truthy.
        if self.characters.value:  # type: ignore
            self.db.value = True  # type: ignore

        if self.mllib.value == "torch":
            self.db.value = False

    def display_text(self, args):
        """Print the first 20 lines of every selected file into ``self.output``."""
        self.output.clear_output()
        with self.output:
            for path in args["new"]:
                with open(path, "r", encoding="utf-8", errors="ignore") as fh:
                    # Iterate lazily: only the previewed lines are read.
                    for line_no, line in enumerate(fh):
                        if line_no == 20:
                            break
                        print(line.strip())

    def _show_label_files(self, repo, labels, other_labels):
        """Fill ``file_list`` with files sampled from ``repo``/<first selected label>.

        The files come from ``sample_from_iterable(..., 10)``. Clears the
        other repository's label selection afterwards. Does nothing when no
        label is selected.
        """
        with self.output:
            if len(labels.value) == 0:
                return
            directory = Path(repo.value) / labels.value[0]
            self.file_list.options = [
                fh.as_posix()
                for fh in sample_from_iterable(directory.glob("**/*"), 10)
            ]
            other_labels.value = []

    def update_train_file_list(self, *args):
        """Observer: list files for the selected training label."""
        self._show_label_files(self.training_repo, self.train_labels,
                               self.test_labels)

    def update_test_file_list(self, *args):
        """Observer: list files for the selected testing label."""
        self._show_label_files(self.testing_repo, self.test_labels,
                               self.train_labels)

    def _create_parameters_input(self) -> JSONType:
        """Input-connector parameters for service creation (``txt`` connector)."""
        return {
            "connector": "txt",
            "characters": self.characters.value,
            "sequence": self.sequence.value,
            "read_forward": self.read_forward.value,
            "alphabet": self.alphabet.value,
            "sparse": self.sparse.value,
            "embedding": self.embedding.value,
            "ordered_words": self.ordered_words.value,
            "wordpiece_tokens": self.wordpiece_tokens.value,
            "punctuation_tokens": self.punctuation_tokens.value,
            "lower_case": self.lower_case.value,
            "word_start": self.word_start.value,
            "suffix_start": self.suffix_start.value,
        }

    def _create_parameters_mllib(self) -> JSONType:
        """Mixin's mllib creation parameters plus embedding size and freeze_traced."""
        dic = super()._create_parameters_mllib()
        dic["embedding_size"] = self.embedding_size.value
        dic["freeze_traced"] = self.freeze_traced.value
        return dic

    def _train_parameters_input(self) -> JSONType:
        """Input-connector parameters sent with a training request."""
        return {
            "alphabet": self.alphabet.value,
            "characters": self.characters.value,
            "count": self.count.value,
            "db": self.db.value,
            "embedding": self.embedding.value,
            "min_count": self.min_count.value,
            "min_word_length": self.min_word_length.value,
            "read_forward": self.read_forward.value,
            "sentences": self.sentences.value,
            "sequence": self.sequence.value,
            "shuffle": self.shuffle.value,
            "test_split": self.tsplit.value,
            "tfidf": self.tfidf.value,
        }
Пример #13
0
def get():
    """Build the interactive widget panel for downloading parcel data.

    The panel guides the user through four steps:

    1. select the AOI (area of interest) and the year;
    2. select how parcels are chosen: parcel ID(s), coordinates, a map
       marker or a drawn polygon;
    3. select the datasets to download (time series and/or chip images);
    4. choose the destination folder (temp or personal data) and download.

    Returns:
        VBox: the assembled widget panel.
    """
    info = Label(
        "1. Select the region and the year to get parcel information.")

    values = config.read()
    # Set the max number of parcels that can be downloaded at once.
    plimit = int(values['set']['plimit'])

    def aois_options():
        """Map ``(label, aoi)`` pairs to the years available for each AOI.

        ``values['set']['data_source']`` selects where AOIs are read from:
        '0' uses ``values['api']['options']``, '1' uses ``values['ds_conf']``.
        """
        values = config.read()
        options = {}
        if values['set']['data_source'] == '0':
            for desc in values['api']['options']['aois']:
                aoi = f"{values['api']['options']['aois'][desc]}"
                options[(desc, aoi)] = values['api']['options']['years'][aoi]
        elif values['set']['data_source'] == '1':
            for aoi in values['ds_conf']:
                desc = f"{values['ds_conf'][aoi]['desc']}"
                confgs = values['ds_conf'][aoi]['years']
                options[(f'{desc} ({aoi})', aoi)] = list(confgs)
        return options

    def aois_years():
        """Map each AOI code to its list of available years."""
        values = config.read()
        years = {}
        if values['set']['data_source'] == '0':
            for desc in values['api']['options']['aois']:
                aoi = values['api']['options']['aois'][desc]
                years[aoi] = values['api']['options']['years'][aoi]
        elif values['set']['data_source'] == '1':
            for aoi in values['ds_conf']:
                years[aoi] = list(values['ds_conf'][aoi]['years'])
        return years

    def years_of(aoi):
        """Return the years listed for *aoi* by ``aois_options()`` ([] if none)."""
        for (_, option_aoi), option_years in aois_options().items():
            if option_aoi == aoi:
                return option_years
        return []

    try:
        aois = Dropdown(
            options=tuple(aois_options()),
            value=values['set']['ds_conf'],
            description='AOI:',
            disabled=False,
        )
    except Exception:
        # The configured default AOI could not be used (e.g. the key is
        # missing or the value is not among the options): let the widget
        # pick its default selection instead.
        aois = Dropdown(
            options=tuple(aois_options()),
            description='AOI:',
            disabled=False,
        )

    # Offer the years of the AOI that is actually selected; the configured
    # default is not necessarily the first option.
    year = Dropdown(
        options=years_of(aois.value),
        description='Year:',
        disabled=False,
    )
    button_refresh = Button(layout=Layout(width='35px'), icon='fa-refresh')

    @button_refresh.on_click
    def button_refresh_on_click(b):
        aois.options = tuple(aois_options())
        year.options = aois_years()[aois.value]

    def table_options_change(change):
        try:
            year.options = aois_years()[change.new]
        except KeyError:
            # The selected AOI is not in the current configuration: reload
            # the AOI list and show the years of the new selection.
            aois.options = tuple(aois_options())
            year.options = aois_years()[aois.value]

    aois.observe(table_options_change, 'value')

    info_method = Label("2. Select a method to get the data.")

    # Tooltips are positional and must follow the order of ``options``.
    method = ToggleButtons(
        options=[('Parcel ID', 2), ('Coordinates', 1), ('Map marker', 3),
                 ('Polygon', 4)],
        value=None,
        description='',
        disabled=False,
        button_style='info',
        tooltips=[
            'Enter parcel ID', 'Enter lat lon', 'Select a point on a map',
            'Get parcels id in a polygon'
        ],
    )

    plon = Text(value='5.664',
                placeholder='Add lon',
                description='Lon:',
                disabled=False)

    plat = Text(value='52.694',
                placeholder='Add lat',
                description='Lat:',
                disabled=False)

    wbox_lat_lot = VBox(children=[plat, plon])

    info_pid = Label(
        "Multiple parcel id codes can be added (comma ',' separated, e.g.: 11111, 22222)."
    )

    pid = Textarea(value='34296',
                   placeholder='12345, 67890',
                   description='Parcel(s) ID:',
                   disabled=False)

    wbox_pids = VBox(children=[info_pid, pid])

    bt_get_ids = Button(description="Find parcels",
                        disabled=False,
                        button_style='info',
                        tooltip='Find parcels within the polygon.',
                        icon='')

    get_ids_box = HBox(
        [bt_get_ids,
         Label("Find the parcels that are in the polygon.")])

    ppoly_out = Output()

    progress = Output()

    def outlog(*text):
        """Print *text* into the download progress output."""
        with progress:
            print(*text)

    def outlog_poly(*text):
        """Print *text* into the polygon search output."""
        with ppoly_out:
            print(*text)

    @bt_get_ids.on_click
    def bt_get_ids_on_click(b):
        with ppoly_out:
            try:
                get_requests = data_source()
                ppoly_out.clear_output()
                # First coordinate ring of the most recently drawn feature.
                polygon = get_maps.polygon_map.feature_collection['features'][
                    -1]['geometry']['coordinates'][0]
                # Serialize as "a_b-c_d-...": each vertex's numbers joined
                # with '_', vertices joined with '-'.
                polygon_str = '-'.join(
                    ['_'.join(map(str, c)) for c in polygon])
                outlog_poly("Geting parcel ids within the polygon...")
                polyids = json.loads(
                    get_requests.ppoly(aois.value, year.value, polygon_str,
                                       False, True))
                outlog_poly(
                    f"'{len(polyids['ogc_fid'])}' parcels where found:")
                outlog_poly(polyids['ogc_fid'])
                # The ids are saved to a file; the "Polygon" download
                # method reads them back from there.
                file = config.get_value(['files', 'pids_poly'])
                with open(file, "w") as text_file:
                    text_file.write('\n'.join(map(str, polyids['ogc_fid'])))
            except Exception as err:
                # Report next to the "Find parcels" button that was clicked.
                outlog_poly("No parcel ids found:", err)

    method_out = Output(layout=Layout(border='1px solid black'))

    def method_options(obj):
        """Show the input widgets that belong to the selected method."""
        with method_out:
            method_out.clear_output()
            if obj['new'] == 1:
                display(wbox_lat_lot)
            elif obj['new'] == 2:
                display(wbox_pids)
            elif obj['new'] == 3:
                display(
                    get_maps.base_map(
                        aois.value,
                        int(config.get_value(['set', 'data_source']))))
            elif obj['new'] == 4:
                display(
                    VBox([
                        get_maps.polygon(
                            aois.value,
                            int(config.get_value(['set', 'data_source']))),
                        get_ids_box, ppoly_out
                    ]))

    method.observe(method_options, 'value')

    info_type = Label("3. Select datasets to download.")

    table_options = HBox([aois, button_refresh, year])

    # ########### Time series options #########################################
    pts_bt = ToggleButton(
        value=False,
        description='Time series',
        disabled=False,
        button_style='success',  # success
        tooltip='Get parcel information',
        icon='toggle-off',
        layout=Layout(width='50%'))

    pts_bands = data_options.pts_bands()

    pts_tstype = SelectMultiple(
        options=data_options.pts_tstype(),
        value=['s2'],
        rows=3,
        description='TS type:',
        disabled=False,
    )

    pts_band = Dropdown(
        options=list(pts_bands['s2']),
        value='',
        description='Band:',
        disabled=False,
    )

    def pts_tstype_change(change):
        # A band can only be chosen when at most one time-series type is
        # selected; with several types the band is cleared and disabled.
        if len(pts_tstype.value) <= 1:
            pts_band.disabled = False
            # Load the bands of the selected type, when it has any listed.
            if change.new and change.new[0] in pts_bands:
                pts_band.options = pts_bands[change.new[0]]
        else:
            pts_band.value = ''
            pts_band.disabled = True

    pts_tstype.observe(pts_tstype_change, 'value')

    pts_options = VBox(children=[pts_tstype, pts_band])

    # ########### Chip images options #########################################
    pci_bt = ToggleButton(value=False,
                          description='Chip images',
                          disabled=False,
                          button_style='success',
                          tooltip='Get parcel information',
                          icon='toggle-off',
                          layout=Layout(width='50%'))

    pci_start_date = DatePicker(value=datetime.date(2019, 6, 1),
                                description='Start Date',
                                disabled=False)

    pci_end_date = DatePicker(value=datetime.date(2019, 6, 30),
                              description='End Date',
                              disabled=False)

    pci_plevel = RadioButtons(
        options=['LEVEL2A', 'LEVEL1C'],
        value='LEVEL2A',
        description='Proces. level:',  # Processing level
        disabled=False,
        layout=Layout(width='50%'))

    pci_chipsize = IntSlider(value=640,
                             min=100,
                             max=5120,
                             step=10,
                             description='Chip size:',
                             disabled=False,
                             continuous_update=False,
                             orientation='horizontal',
                             readout=True,
                             readout_format='d')

    pci_bands = data_options.pci_bands()

    pci_satellite = RadioButtons(options=list(pci_bands),
                                 value='Sentinel 2',
                                 disabled=True,
                                 layout=Layout(width='100px'))

    pci_band = SelectMultiple(options=list(pci_bands['Sentinel 2']),
                              value=['B04'],
                              rows=11,
                              description='Band:',
                              disabled=False)

    sats_plevel = HBox([pci_satellite, pci_plevel])

    def on_sat_change(change):
        sat = change.new
        pci_band.options = pci_bands[sat]

    pci_satellite.observe(on_sat_change, 'value')

    pci_options = VBox(children=[
        pci_start_date, pci_end_date, sats_plevel, pci_chipsize, pci_band
    ])

    # ########### General options #############################################
    pts_wbox = VBox(children=[])
    pci_wbox = VBox(children=[])

    def pts_observe(button):
        """Show or hide the time-series options with the toggle button."""
        if button['new']:
            pts_bt.icon = 'toggle-on'
            pts_wbox.children = [pts_options]
        else:
            pts_bt.icon = 'toggle-off'
            pts_wbox.children = []

    def pci_observe(button):
        """Show or hide the chip-image options with the toggle button."""
        if button['new']:
            pci_bt.icon = 'toggle-on'
            pci_wbox.children = [pci_options]
        else:
            pci_bt.icon = 'toggle-off'
            pci_wbox.children = []

    pts_bt.observe(pts_observe, names='value')
    pci_bt.observe(pci_observe, names='value')

    pts = VBox(children=[pts_bt, pts_wbox], layout=Layout(width='40%'))
    pci = VBox(children=[pci_bt, pci_wbox], layout=Layout(width='40%'))

    data_types = HBox(children=[pts, pci])

    info_get = Label("4. Download the selected data.")

    bt_get = Button(description='Download',
                    disabled=False,
                    button_style='warning',
                    tooltip='Send the request',
                    icon='download')

    path_temp = config.get_value(['paths', 'temp'])
    path_data = config.get_value(['paths', 'data'])

    info_paths = HTML("".join([
        "<style>div.c {line-height: 1.1;}</style>",
        "<div class='c';>By default data will be stored in the temp folder ",
        f"({path_temp}), you will be asked to empty the temp folder each time ",
        "you start the notebook.<br>In your personal data folder ",
        f"({path_data}) you can permanently store the data.</div>"
    ]))

    paths = RadioButtons(options=[
        (f"Temporary folder: '{path_temp}'.", path_temp),
        (f"Personal data folder: '{path_data}'.", path_data)
    ],
                         layout={'width': 'max-content'},
                         value=path_temp)

    paths_box = Box([Label(value="Select folder:"), paths])

    def file_len(fname):
        """Return the number of lines in *fname* (0 for an empty file)."""
        with open(fname) as f:
            return sum(1 for _ in f)

    def get_data(parcel):
        """Export the information and the selected datasets of one parcel.

        Args:
            parcel: decoded parcel JSON; ``parcel['ogc_fid'][0]`` is used as
                the parcel id.

        Raises:
            ValueError: if the configured data source is neither 0 nor 1.
        """
        get_requests = data_source()
        pid = parcel['ogc_fid'][0]
        source = int(config.get_value(['set', 'data_source']))
        if source == 0:
            datapath = f'{paths.value}{aois.value}{year.value}/parcel_{pid}/'
        elif source == 1:
            ds_conf = config.get_value(['set', 'ds_conf'])
            datapath = f'{paths.value}{ds_conf}/parcel_{pid}/'
        else:
            raise ValueError(f"Unknown data source: '{source}'")
        file_pinf = f"{datapath}{pid}_information"

        outlog(data_handler.export(parcel, 10, file_pinf))

        if pts_bt.value is True:
            outlog(f"Getting time series for parcel: '{pid}',",
                   f"({pts_tstype.value} {pts_band.value}).")
            for pts in pts_tstype.value:
                ts = json.loads(
                    get_requests.pts(aois.value, year.value, pid, pts,
                                     pts_band.value))
                band = ''
                if pts_band.value != '':
                    band = f"_{pts_band.value}"
                file_ts = f"{datapath}{pid}_time_series_{pts}{band}"
                outlog(data_handler.export(ts, 11, file_ts))
        if pci_bt.value is True:
            files_pci = f"{datapath}{pid}_chip_images/"
            outlog(f"Getting '{pci_band.value}' chip images for parcel: {pid}")
            with progress:
                get_requests.rcbl(parcel, pci_start_date.value,
                                  pci_end_date.value, pci_band.value,
                                  pci_satellite.value, pci_chipsize.value,
                                  files_pci)
            # The images list of the first band is checked to tell whether
            # anything was downloaded (more than just a header line).
            filet = f'{files_pci}{pid}_images_list.{pci_band.value[0]}.csv'
            if file_len(filet) > 1:
                outlog(
                    f"Completed, all GeoTIFFs for bands '{pci_band.value}' are ",
                    f"downloaded in the folder: '{datapath}/{pid}_chip_images'"
                )
            else:
                outlog(
                    "No files where downloaded, please check your configurations"
                )

    def get_from_location(lon, lat):
        """Find the parcel at *lon*, *lat* and download its data."""
        get_requests = data_source()
        outlog(f"Finding parcel information for coordinates: {lon}, {lat}")
        parcel = json.loads(
            get_requests.ploc(aois.value, year.value, lon, lat, True))
        pid = parcel['ogc_fid'][0]
        outlog(f"The parcel '{pid}' was found at this location.")
        try:
            get_data(parcel)
        except Exception as err:
            outlog(err)

    def get_from_id(pids):
        """Download the data of every parcel id in *pids*.

        A failure for one parcel is reported and the remaining parcels are
        still processed.
        """
        get_requests = data_source()
        outlog(f"Getting parcels information for: '{pids}'")
        for pid in pids:
            try:
                parcel = json.loads(
                    get_requests.pid(aois.value, year.value, pid, True))
                get_data(parcel)
            except Exception as err:
                outlog(err)

    @bt_get.on_click
    def bt_get_on_click(b):
        progress.clear_output()
        if method.value == 1:
            # Read the coordinates before the try so the error message can
            # always reference them.
            lon, lat = plon.value, plat.value
            try:
                with progress:
                    get_from_location(lon, lat)
            except Exception as err:
                outlog(
                    f"Could not get parcel information for location '{lon}', '{lat}': {err}"
                )

        elif method.value == 2:
            try:
                with progress:
                    # Accept commas and/or new lines as separators; drop
                    # spaces and empty entries (e.g. from a trailing comma).
                    raw = pid.value.replace(" ", "").replace("\n", ",")
                    pids = [p.strip() for p in raw.split(",")]
                    get_from_id([p for p in pids if p])
            except Exception as err:
                outlog(f"Could not get parcel information: {err}")

        elif method.value == 3:
            try:
                marker = get_maps.base_map.map_marker
                # NOTE(review): rounding to 2 decimals (~1 km in latitude)
                # may resolve a different parcel than the one under the
                # marker - confirm this precision is intended.
                lon = str(round(marker.location[1], 2))
                lat = str(round(marker.location[0], 2))
                get_from_location(lon, lat)
            except Exception as err:
                outlog(f"Could not get parcel information: {err}")
        elif method.value == 4:
            try:
                file = config.get_value(['files', 'pids_poly'])
                with open(file, "r") as text_file:
                    # Skip blank lines so no empty id is requested.
                    pids = [p.strip() for p in text_file.read().splitlines()
                            if p.strip()]
            except Exception as err:
                outlog("No pids file found.", err)
                return
            try:
                outlog("Geting data form the parcels:")
                outlog(pids)
                if len(pids) <= plimit:
                    get_from_id(pids)
                else:
                    outlog(
                        "You exceeded the maximum amount of selected parcels ",
                        f"({plimit}) to get data. Please select smaller area.")
            except Exception as err:
                outlog(f"Could not get parcel information: {err}")
        else:
            outlog("Please select method to get parcel information.")

    return VBox([
        info, table_options, info_method, method, method_out, info_type,
        data_types, info_get, info_paths, paths_box, bt_get, progress
    ])