Example #1
0
 def __init__(self) -> None:
     """Create a poster with default settings and no tracks.

     Page size defaults to 200 x 300, units to "metric", and the colour
     scheme to a dark background with light text. Track containers start
     empty, and the translation callable is reset via ``set_language(None)``.
     """
     # Header/footer identity; unset until assigned.
     self._athlete: typing.Optional[str] = None
     self._title: typing.Optional[str] = None
     # Track storage plus the length statistics derived from it.
     self.tracks_by_date: typing.Dict[str, typing.List[Track]] = defaultdict(list)
     self.tracks: typing.List[Track] = []
     self.length_range = QuantityRange()
     self.length_range_by_date = QuantityRange()
     self.total_length_year_dict: typing.Dict[int, pint.quantity.Quantity] = defaultdict(int)
     # Presentation defaults.
     self.units = "metric"
     self.colors = dict(
         background="#222222",
         text="#FFFFFF",
         special="#FFFF00",
         track="#4DD2FF",
     )
     self.special_distance: typing.Dict[str, float] = dict(special_distance1=10, special_distance2=20)
     self.width, self.height = 200, 300
     self.years = YearRange()
     self.tracks_drawer: typing.Optional["TracksDrawer"] = None
     # Translation callable; None until set_language() runs just below.
     self._trans: typing.Optional[typing.Callable[[str], str]] = None
     self.set_language(None)
Example #2
0
 def __init__(self) -> None:
     """Initialise loader defaults: 1 km minimum length, no cache directory."""
     # Minimum track length, expressed as a pint quantity.
     self._min_length: pint.quantity.Quantity = 1 * Units().km
     # No files are marked special; the year range uses YearRange's defaults.
     self.special_file_names: typing.List[str] = list()
     self.year_range = YearRange()
     # No cache directory and no Strava cache file configured yet.
     self.cache_dir: typing.Optional[str] = None
     self.strava_cache_file = ""
     # Lookup table keyed by file name (presumably source -> cache path).
     self._cache_file_names: typing.Dict[str, str] = dict()
Example #3
0
 def __init__(self, workers: typing.Optional[int]) -> None:
     """Initialise loader defaults.

     Args:
         workers: Worker count stored for later use; may be None.
     """
     self._workers = workers
     # Minimum track length, expressed as a pint quantity.
     self._min_length: pint.quantity.Quantity = 1 * Units().km
     # No files are marked special; the year range uses YearRange's defaults.
     self.special_file_names: typing.List[str] = list()
     self.year_range = YearRange()
     # No cache directory and no Strava cache file configured yet.
     self.cache_dir: typing.Optional[str] = None
     self.strava_cache_file = ""
     # Lookup table keyed by file name (presumably source -> cache path).
     self._cache_file_names: typing.Dict[str, str] = dict()
     # Activity filter; "all" is the default value.
     self._activity_type: str = "all"
Example #4
0
 def __compute_years(self, tracks):
     """Fill ``self.years`` from the tracks' start times, unless already set.

     An existing (non-None) ``self.years`` is left untouched.
     """
     if self.years is None:
         self.years = YearRange()
         for track in tracks:
             self.years.add(track.start_time)
Example #5
0
class Poster:
    """Create a poster from track data.

    Attributes:
        athlete: Name of athlete to be displayed on poster.
        title: Title of poster.
        tracks_by_date: Tracks organized temporally if needed.
        tracks: List of tracks to be used in the poster.
        length_range: Range of lengths of tracks in poster.
        length_range_by_date: Range of lengths organized temporally.
        units: Length units to be used in poster.
        colors: Colors for various components of the poster.
        width: Poster width.
        height: Poster height.
        years: Years included in the poster.
        tracks_drawer: drawer used to draw the poster.
        trans: Translation function applied to the poster's labels.
        total_length_year_dict: Total track length per start year; rebuilt
            whenever the footer statistics are computed.

    Methods:
        set_language: Select the locale and translations for labels.
        set_tracks: Associate the Poster with a set of tracks
        draw: Draw the tracks on the poster.
        m2u: Convert meters to kilometers or miles based on units
        u: Return distance unit (km or mi)
        format_distance: Format a length in meters with the poster's unit.
    """

    def __init__(self):
        self.athlete = None
        self.title = None
        self.tracks_by_date = {}
        self.tracks = []
        self.length_range = None
        self.length_range_by_date = None
        self.units = "metric"
        self.colors = {
            "background": "#222222",
            "text": "#FFFFFF",
            "special": "#FFFF00",
            "track": "#4DD2FF",
        }
        self.special_distance = {
            "special_distance1": "10",
            "special_distance2": "20"
        }
        self.width = 200
        self.height = 300
        # None means "derive the years from the tracks" (see __compute_years).
        self.years = None
        self.tracks_drawer = None
        # Defined up front so the attribute exists even before draw() runs.
        self.total_length_year_dict = defaultdict(int)
        self.trans = None
        self.set_language(None)

    def set_language(self, language):
        """Set the process locale and the translation function for labels.

        Args:
            language: Language/locale name; ".utf8" is appended when calling
                ``locale.setlocale``. A falsy value (or a locale that cannot
                be set) selects untranslated labels.
        """
        if language:
            try:
                locale.setlocale(locale.LC_ALL, f"{language}.utf8")
            except locale.Error as e:
                print(f'Cannot set locale to "{language}": {e}')
                language = None

        # Fall-back to NullTranslations, if the specified language translation cannot be found.
        if language:
            lang = gettext.translation("gpxposter",
                                       localedir="locale",
                                       languages=[language],
                                       fallback=True)
        else:
            lang = gettext.NullTranslations()
        self.trans = lang.gettext

    def set_tracks(self, tracks):
        """Associate the set of tracks with this poster.

        In addition to setting self.tracks, also compute the necessary attributes for the Poster
        based on this set of tracks. Only tracks whose start time lies in
        ``self.years`` are grouped by date and counted in the length ranges.
        """
        self.tracks = tracks
        self.tracks_by_date = {}
        self.length_range = ValueRange()
        self.length_range_by_date = ValueRange()
        self.__compute_years(tracks)
        for track in tracks:
            if not self.years.contains(track.start_time):
                continue
            text_date = track.start_time.strftime("%Y-%m-%d")
            self.tracks_by_date.setdefault(text_date, []).append(track)
            self.length_range.extend(track.length)
        # A distinct loop name keeps the ``tracks`` argument from being shadowed.
        for date_tracks in self.tracks_by_date.values():
            self.length_range_by_date.extend(sum(t.length for t in date_tracks))

    def draw(self, drawer, output):
        """Set the Poster's drawer and write the poster as SVG to ``output``."""
        self.tracks_drawer = drawer
        d = svgwrite.Drawing(output, (f"{self.width}mm", f"{self.height}mm"))
        d.viewbox(0, 0, self.width, self.height)
        d.add(
            d.rect((0, 0), (self.width, self.height),
                   fill=self.colors["background"]))
        self.__draw_header(d)
        self.__draw_footer(d)
        self.__draw_tracks(d, XY(self.width - 20, self.height - 30 - 30),
                           XY(10, 30))
        d.save()

    def m2u(self, m):
        """Convert meters to kilometers ("metric") or miles (any other units value)."""
        if self.units == "metric":
            return 0.001 * m
        return 0.001 * m / 1.609344

    def u(self):
        """Return the unit of distance being used on the Poster ("km" or "mi")."""
        if self.units == "metric":
            return "km"
        return "mi"

    def format_distance(self, d: float) -> str:
        """Formats a distance using the locale specific float format and the selected unit."""
        return format_float(self.m2u(d)) + " " + self.u()

    def __draw_tracks(self, d, size: "XY", offset: "XY"):
        # Delegate the actual track rendering to the configured drawer.
        self.tracks_drawer.draw(d, size, offset)

    def __draw_header(self, d):
        """Draw the poster title in the top-left corner."""
        text_color = self.colors["text"]
        title_style = "font-size:12px; font-family:Arial; font-weight:bold;"
        d.add(
            d.text(self.title,
                   insert=(10, 20),
                   fill=text_color,
                   style=title_style))

    def __draw_footer(self, d):
        """Draw the athlete name and the track statistics along the bottom edge."""
        text_color = self.colors["text"]
        header_style = "font-size:4px; font-family:Arial"
        value_style = "font-size:9px; font-family:Arial"
        small_value_style = "font-size:3px; font-family:Arial"

        (
            total_length,
            average_length,
            min_length,
            max_length,
            weeks,
        ) = self.__compute_track_statistics()
        # A poster without tracks has no active weeks; avoid dividing by zero.
        weekly = len(self.tracks) / weeks if weeks else 0.0

        d.add(
            d.text(
                self.trans("ATHLETE"),
                insert=(10, self.height - 20),
                fill=text_color,
                style=header_style,
            ))
        d.add(
            d.text(
                self.athlete,
                insert=(10, self.height - 10),
                fill=text_color,
                style=value_style,
            ))
        d.add(
            d.text(
                self.trans("STATISTICS"),
                insert=(120, self.height - 20),
                fill=text_color,
                style=header_style,
            ))
        d.add(
            d.text(
                self.trans("Number") + f": {len(self.tracks)}",
                insert=(120, self.height - 15),
                fill=text_color,
                style=small_value_style,
            ))
        d.add(
            d.text(
                self.trans("Weekly") + ": " + format_float(weekly),
                insert=(120, self.height - 10),
                fill=text_color,
                style=small_value_style,
            ))
        d.add(
            d.text(
                self.trans("Total") + ": " +
                self.format_distance(total_length),
                insert=(139, self.height - 15),
                fill=text_color,
                style=small_value_style,
            ))
        d.add(
            d.text(
                self.trans("Avg") + ": " +
                self.format_distance(average_length),
                insert=(139, self.height - 10),
                fill=text_color,
                style=small_value_style,
            ))
        d.add(
            d.text(
                self.trans("Min") + ": " + self.format_distance(min_length),
                insert=(167, self.height - 15),
                fill=text_color,
                style=small_value_style,
            ))
        d.add(
            d.text(
                self.trans("Max") + ": " + self.format_distance(max_length),
                insert=(167, self.height - 10),
                fill=text_color,
                style=small_value_style,
            ))

    def __compute_track_statistics(self):
        """Return footer statistics for ``self.tracks``.

        Returns:
            (total_length, average_length, lower, upper, weeks), where lower and
            upper are the bounds reported by ``ValueRange`` and weeks is the
            number of distinct ISO (year, week) pairs containing a track start.
            With no tracks every value is zero instead of raising
            ZeroDivisionError.

        Side effect: rebuilds ``self.total_length_year_dict`` (length per start year).
        """
        total_length_year_dict = defaultdict(int)
        self.total_length_year_dict = total_length_year_dict
        if not self.tracks:
            return 0, 0.0, 0.0, 0.0, 0

        length_range = ValueRange()
        total_length = 0
        active_weeks = set()
        for t in self.tracks:
            total_length += t.length
            total_length_year_dict[t.start_time.year] += t.length
            length_range.extend(t.length)
            # (ISO year, ISO week): pairing the calendar year with the ISO week
            # would merge/split the weeks that straddle New Year.
            active_weeks.add(tuple(t.start_time.isocalendar()[:2]))
        return (
            total_length,
            total_length / len(self.tracks),
            length_range.lower(),
            length_range.upper(),
            len(active_weeks),
        )

    def __compute_years(self, tracks):
        # Keep a year range that was set explicitly; otherwise derive it from the tracks.
        if self.years is not None:
            return
        self.years = YearRange()
        for t in tracks:
            self.years.add(t.start_time)
Example #6
0
class Poster:
    """Create a poster from track data.

    Attributes:
        athlete: Name of athlete to be displayed on poster.
        title: Title of poster.
        tracks_by_date: Tracks organized temporally if needed.
        tracks: List of tracks to be used in the poster.
        length_range: Range of lengths of tracks in poster.
        length_range_by_date: Range of lengths organized temporally.
        units: Length units to be used in poster.
        colors: Colors for various components of the poster.
        width: Poster width.
        height: Poster height.
        years: Years included in the poster.
        tracks_drawer: drawer used to draw the poster.

    Methods:
        set_tracks: Associate the Poster with a set of tracks
        draw: Draw the tracks on the poster.
        m2u: Convert meters to kilometers or miles based on units
        u: Return distance unit (km or mi)
    """

    def __init__(self) -> None:
        self._athlete: typing.Optional[str] = None
        self._title: typing.Optional[str] = None
        self.tracks_by_date: typing.Dict[str, typing.List[Track]] = defaultdict(list)
        self.tracks: typing.List[Track] = []
        self.length_range = QuantityRange()
        self.length_range_by_date = QuantityRange()
        self.total_length_year_dict: typing.Dict[int, pint.quantity.Quantity] = defaultdict(int)
        self.units = "metric"
        self.colors = {
            "background": "#222222",
            "text": "#FFFFFF",
            "special": "#FFFF00",
            "track": "#4DD2FF",
        }
        self.special_distance: typing.Dict[str, float] = {"special_distance1": 10, "special_distance2": 20}
        self.width = 200
        self.height = 300
        self.years = YearRange()
        self.tracks_drawer: typing.Optional["TracksDrawer"] = None
        self._trans: typing.Optional[typing.Callable[[str], str]] = None
        self.set_language(None, None)

    def set_language(self, language: typing.Optional[str], localedir: typing.Optional[str]) -> None:
        """Set the process locale and the translation function for labels.

        Args:
            language: Language/locale name; ".utf8" is appended when calling
                ``locale.setlocale``. A falsy value (or a locale that cannot be
                set) selects untranslated labels.
            localedir: Directory holding the "gpxposter" gettext catalogues, or
                None for gettext's default location.
        """
        if language:
            try:
                locale.setlocale(locale.LC_ALL, f"{language}.utf8")
            except locale.Error as e:
                log.warning("Unable to set the locale to %s (%s)", language, str(e))
                language = None

        # Fall-back to NullTranslations, if the specified language translation cannot be found.
        if language:
            lang = gettext.translation("gpxposter", localedir=localedir, languages=[language], fallback=True)
            if len(lang.info()) == 0:
                log.warning(
                    "Unable to load translations for %s from %s; falling back to the default translation.",
                    language,
                    localedir if localedir else "the system's default locale directory",
                )
        else:
            lang = gettext.NullTranslations()
        self._trans = lang.gettext

    def translate(self, s: str) -> str:
        """Translate ``s`` with the active catalogue; identity if none is set."""
        if self._trans is None:
            return s
        return self._trans(s)

    def month_name(self, month: int) -> str:
        """Return the translated name of ``month`` (1 = January ... 12 = December)."""
        assert 1 <= month <= 12

        # Literal translate() calls are kept so message extraction can find them.
        return [
            self.translate("January"),
            self.translate("February"),
            self.translate("March"),
            self.translate("April"),
            self.translate("May"),
            self.translate("June"),
            self.translate("July"),
            self.translate("August"),
            self.translate("September"),
            self.translate("October"),
            self.translate("November"),
            self.translate("December"),
        ][month - 1]

    def set_athlete(self, athlete: str) -> None:
        self._athlete = athlete

    def set_title(self, title: str) -> None:
        self._title = title

    def set_tracks(self, tracks: typing.List[Track]) -> None:
        """Associate the set of tracks with this poster.

        In addition to setting self.tracks, also compute the necessary attributes for the Poster
        based on this set of tracks. Only tracks starting inside ``self.years``
        are grouped by date and counted in the length ranges.
        """
        self.tracks = tracks
        self.tracks_by_date.clear()
        self.length_range.clear()
        self.length_range_by_date.clear()
        self._compute_years(tracks)
        for track in tracks:
            if not self.years.contains(track.start_time()):
                continue
            text_date = track.start_time().strftime("%Y-%m-%d")
            self.tracks_by_date[text_date].append(track)
            self.length_range.extend(track.length())
        for date_tracks in self.tracks_by_date.values():
            length = sum(t.length() for t in date_tracks)
            self.length_range_by_date.extend(length)

    def draw(self, drawer: "TracksDrawer", output: str) -> None:
        """Set the Poster's drawer and write the poster as SVG to ``output``."""
        self.tracks_drawer = drawer
        d = svgwrite.Drawing(output, (f"{self.width}mm", f"{self.height}mm"))
        d.viewbox(0, 0, self.width, self.height)
        d.add(d.rect((0, 0), (self.width, self.height), fill=self.colors["background"]))
        self._draw_header(d)
        self._draw_footer(d)
        self._draw_tracks(d, XY(self.width - 20, self.height - 30 - 30), XY(10, 30))
        d.save()

    def m2u(self, m: pint.quantity.Quantity) -> float:
        """Convert meters to kilometers or miles, according to units."""
        if self.units == "metric":
            return m.m_as(Units().km)
        return m.m_as(Units().mile)

    def u(self) -> str:
        """Return the unit of distance being used on the Poster."""
        if self.units == "metric":
            return self.translate("km")
        return self.translate("mi")

    def format_distance(self, d: pint.quantity.Quantity) -> str:
        """Formats a distance using the locale specific float format and the selected unit."""
        return format_float(self.m2u(d)) + " " + self.u()

    def _draw_tracks(self, d: svgwrite.Drawing, size: XY, offset: XY) -> None:
        assert self.tracks_drawer

        g = d.g(id="tracks")
        d.add(g)

        self.tracks_drawer.draw(d, g, size, offset)

    def _draw_header(self, d: svgwrite.Drawing) -> None:
        g = d.g(id="header")
        d.add(g)

        text_color = self.colors["text"]
        title_style = "font-size:12px; font-family:Arial; font-weight:bold;"
        assert self._title is not None
        g.add(d.text(self._title, insert=(10, 20), fill=text_color, style=title_style))

    def _draw_footer(self, d: svgwrite.Drawing) -> None:
        """Draw the athlete name and the track statistics in a "footer" group."""
        g = d.g(id="footer")
        d.add(g)

        text_color = self.colors["text"]
        header_style = "font-size:4px; font-family:Arial"
        value_style = "font-size:9px; font-family:Arial"
        small_value_style = "font-size:3px; font-family:Arial"

        (
            total_length,
            average_length,
            length_range,
            weeks,
        ) = self._compute_track_statistics()
        # A poster without tracks has no active weeks; avoid dividing by zero.
        weekly = len(self.tracks) / weeks if weeks else 0.0

        g.add(
            d.text(
                self.translate("ATHLETE"),
                insert=(10, self.height - 20),
                fill=text_color,
                style=header_style,
            )
        )
        g.add(
            d.text(
                self._athlete,
                insert=(10, self.height - 10),
                fill=text_color,
                style=value_style,
            )
        )
        g.add(
            d.text(
                self.translate("STATISTICS"),
                insert=(120, self.height - 20),
                fill=text_color,
                style=header_style,
            )
        )
        g.add(
            d.text(
                self.translate("Number") + f": {len(self.tracks)}",
                insert=(120, self.height - 15),
                fill=text_color,
                style=small_value_style,
            )
        )
        g.add(
            d.text(
                self.translate("Weekly") + ": " + format_float(weekly),
                insert=(120, self.height - 10),
                fill=text_color,
                style=small_value_style,
            )
        )
        g.add(
            d.text(
                self.translate("Total") + ": " + self.format_distance(total_length),
                insert=(141, self.height - 15),
                fill=text_color,
                style=small_value_style,
            )
        )
        g.add(
            d.text(
                self.translate("Avg") + ": " + self.format_distance(average_length),
                insert=(141, self.height - 10),
                fill=text_color,
                style=small_value_style,
            )
        )
        if length_range.is_valid():
            min_length = length_range.lower()
            max_length = length_range.upper()
            assert min_length is not None
            assert max_length is not None
        else:
            # format_distance() -> m2u() calls .m_as(), so the fallback must be
            # a pint quantity; a bare 0.0 would raise AttributeError.
            min_length = 0.0 * Units().meter
            max_length = 0.0 * Units().meter
        g.add(
            d.text(
                self.translate("Min") + ": " + self.format_distance(min_length),
                insert=(167, self.height - 15),
                fill=text_color,
                style=small_value_style,
            )
        )
        g.add(
            d.text(
                self.translate("Max") + ": " + self.format_distance(max_length),
                insert=(167, self.height - 10),
                fill=text_color,
                style=small_value_style,
            )
        )

    def _compute_track_statistics(
        self,
    ) -> typing.Tuple[pint.quantity.Quantity, pint.quantity.Quantity, QuantityRange, int]:
        """Return (total length, average length, length range, active weeks).

        Active weeks counts distinct ISO (year, week) pairs containing a track
        start. With no tracks the average is 0 m instead of raising
        ZeroDivisionError. Side effect: refills ``self.total_length_year_dict``.
        """
        length_range = QuantityRange()
        total_length = 0.0 * Units().meter
        self.total_length_year_dict.clear()
        weeks: typing.Set[typing.Tuple[int, int]] = set()
        for t in self.tracks:
            total_length += t.length()
            self.total_length_year_dict[t.start_time().year] += t.length()
            length_range.extend(t.length())
            # (ISO year, ISO week): pairing the calendar year with the ISO week
            # would merge/split the weeks that straddle New Year.
            weeks.add(tuple(t.start_time().isocalendar()[:2]))
        if self.tracks:
            average_length = total_length / len(self.tracks)
        else:
            average_length = 0.0 * Units().meter
        return (
            total_length,
            average_length,
            length_range,
            len(weeks),
        )

    def _compute_years(self, tracks: typing.List[Track]) -> None:
        self.years.clear()
        for t in tracks:
            self.years.add(t.start_time())
Example #7
0
class TrackLoader:
    """Handle the loading of tracks from cache and/or GPX files

    Attributes:
        min_length: All tracks shorter than this value are filtered out.
        special_file_names: Tracks marked as special in command line args
        year_range: All tracks outside of this range will be filtered out.
        cache_dir: Directory used to store cached tracks
        strava_cache_file: JSON cache of Strava activities (set by load_strava_tracks)

    Methods:
        clear_cache: Remove cache directory
        load_tracks: Load all data from cache and GPX files
        load_strava_tracks: Load activities from Strava, using a JSON cache
    """
    def __init__(self, workers: typing.Optional[int]) -> None:
        # Passed to ProcessPoolExecutor as max_workers; a value <= 1 makes the
        # loaders run sequentially in this process instead.
        self._workers = workers
        self._min_length: pint.quantity.Quantity = 1 * Units().km
        self.special_file_names: typing.List[str] = []
        self.year_range = YearRange()
        self.cache_dir: typing.Optional[str] = None
        self.strava_cache_file = ""
        # Memo: source file name -> cache file name (checksums are costly).
        self._cache_file_names: typing.Dict[str, str] = {}

    def set_cache_dir(self, cache_dir: str) -> None:
        self.cache_dir = cache_dir

    def clear_cache(self) -> None:
        """Remove cache directory, if it exists"""
        if self.cache_dir is not None and os.path.isdir(self.cache_dir):
            log.info("Removing cache dir: %s", self.cache_dir)
            try:
                shutil.rmtree(self.cache_dir)
            except OSError as e:
                log.error("Failed: %s", str(e))

    def set_min_length(self, min_length: pint.quantity.Quantity) -> None:
        self._min_length = min_length

    def load_tracks(self, base_dir: str) -> typing.List[Track]:
        """Load tracks base_dir and return as a List of tracks"""
        file_names = list(self._list_gpx_files(base_dir))
        log.info("GPX files: %d", len(file_names))

        tracks: typing.List[Track] = []

        # load track from cache
        cached_tracks: typing.Dict[str, Track] = {}
        if self.cache_dir:
            log.info("Trying to load %d track(s) from cache...",
                     len(file_names))
            cached_tracks = self._load_tracks_from_cache(file_names)
            log.info("Loaded tracks from cache: %d", len(cached_tracks))
            tracks = list(cached_tracks.values())

        # load remaining gpx files
        remaining_file_names = [
            f for f in file_names if f not in cached_tracks
        ]
        if remaining_file_names:
            log.info(
                "Trying to load %d track(s) from GPX files; this may take a while...",
                len(remaining_file_names))
            timezone_adjuster = TimezoneAdjuster()
            loaded_tracks = self._load_tracks(remaining_file_names,
                                              timezone_adjuster)
            tracks.extend(loaded_tracks.values())
            log.info("Conventionally loaded tracks: %d", len(loaded_tracks))
            self._store_tracks_to_cache(loaded_tracks)

        return self._filter_and_merge_tracks(tracks)

    def load_strava_tracks(self, strava_config: str) -> typing.List[Track]:
        """Load activities from Strava, reusing and refreshing a JSON cache.

        ``strava_config`` is a JSON file. Its optional "activity_type" entry
        (a string or a list of strings) filters activities; the remaining
        entries are passed to ``Client.refresh_access_token``. When a cache dir
        is set, the cache file is ``cache_dir/<basename of strava_config>``.
        """
        tracks: typing.List[Track] = []
        cached_names: typing.Set[str] = set()
        if self.cache_dir:
            # Use only the base name: os.path.join discards cache_dir for an
            # absolute config path, which made the cache overwrite the config.
            self.strava_cache_file = os.path.join(self.cache_dir,
                                                  os.path.basename(strava_config))
            if os.path.isfile(self.strava_cache_file):
                with open(self.strava_cache_file) as f:
                    strava_cache_data = json.load(f)
                tracks = [
                    self._strava_cache_to_track(i) for i in strava_cache_data
                ]
                cached_names = {track.file_names[0] for track in tracks}

        with open(strava_config) as f:
            strava_data = json.load(f)
        filter_type = strava_data.pop("activity_type", None)
        allowed_types = [filter_type] if isinstance(filter_type, str) else filter_type
        client = Client()
        response = client.refresh_access_token(**strava_data)
        client.access_token = response["access_token"]
        filter_dict = {"before": datetime.datetime.utcnow()}
        if tracks:
            # Re-fetch a small overlap; already-cached ids are skipped below.
            max_time = max(track.start_time() for track in tracks)
            filter_dict = {"after": max_time - datetime.timedelta(days=2)}
        for activity in client.get_activities(**filter_dict):
            # tricky to pass the timezone
            if str(activity.id) in cached_names:
                continue
            if filter_type and activity.type not in allowed_types:
                continue
            t = Track()
            t.load_strava(activity)
            tracks.append(t)
        self._store_strava_tracks_to_cache(tracks)
        return self._filter_and_merge_tracks(tracks)

    def _filter_tracks(self, tracks: typing.List[Track]) -> typing.List[Track]:
        """Drop empty, untimed and out-of-range tracks; flag special ones."""
        special_names = set(self.special_file_names)
        filtered_tracks = []
        for t in tracks:
            file_name = t.file_names[0]
            if t.length().magnitude == 0:
                log.info("%s: skipping empty track", file_name)
            elif not t.has_time():
                log.info("%s: skipping track without start or end time",
                         file_name)
            elif not self.year_range.contains(t.start_time()):
                log.info("%s: skipping track with wrong year %d", file_name,
                         t.start_time().year)
            else:
                t.special = file_name in special_names
                filtered_tracks.append(t)
        return filtered_tracks

    def _filter_and_merge_tracks(
            self, tracks: typing.List[Track]) -> typing.List[Track]:
        tracks = self._filter_tracks(tracks)
        # merge tracks that took place within one hour
        tracks = self._merge_tracks(tracks)
        # filter out tracks with length < min_length
        return [t for t in tracks if t.length() >= self._min_length]

    @staticmethod
    def _merge_tracks(tracks: typing.List[Track]) -> typing.List[Track]:
        """Append each track to its predecessor when it starts < 1 h after it ended."""
        log.info("Merging tracks...")
        tracks = sorted(tracks, key=lambda t1: t1.start_time())
        merged_tracks = []
        last_end_time = None
        for t in tracks:
            if last_end_time is None:
                merged_tracks.append(t)
            else:
                dt = (t.start_time() - last_end_time).total_seconds()
                if 0 < dt < 3600:
                    merged_tracks[-1].append(t)
                else:
                    merged_tracks.append(t)
            last_end_time = t.end_time()
        log.info("Merged %d track(s)", len(tracks) - len(merged_tracks))
        return merged_tracks

    def _load_tracks(
            self, file_names: typing.List[str],
            timezone_adjuster: TimezoneAdjuster) -> typing.Dict[str, Track]:
        """Parse GPX files, sequentially or in a process pool; log failures."""
        tracks = {}

        if self._workers is not None and self._workers <= 1:
            for file_name in file_names:
                try:
                    t = load_gpx_file(file_name, timezone_adjuster)
                except TrackLoadError as e:
                    log.error("Error while loading %s: %s", file_name, str(e))
                else:
                    tracks[file_name] = t
            return tracks

        with concurrent.futures.ProcessPoolExecutor(
                max_workers=self._workers) as executor:
            future_to_file_name = {
                executor.submit(load_gpx_file, file_name, timezone_adjuster):
                file_name
                for file_name in file_names
            }
            for future in concurrent.futures.as_completed(future_to_file_name):
                file_name = future_to_file_name[future]
                try:
                    t = future.result()
                except TrackLoadError as e:
                    log.error("Error while loading %s: %s", file_name, str(e))
                else:
                    tracks[file_name] = t

        return tracks

    def _load_tracks_from_cache(
            self, file_names: typing.List[str]) -> typing.Dict[str, Track]:
        """Load cached tracks; any failure is treated as a silent cache miss."""
        tracks = {}

        if self._workers is not None and self._workers <= 1:
            for file_name in file_names:
                try:
                    t = load_cached_track_file(
                        self._get_cache_file_name(file_name), file_name)
                except Exception:
                    # silently ignore failed cache load attempts
                    pass
                else:
                    tracks[file_name] = t
            return tracks

        with concurrent.futures.ProcessPoolExecutor(
                max_workers=self._workers) as executor:
            future_to_file_name = {}
            for file_name in file_names:
                # Resolve the cache name here, guarded like the sequential
                # path: one unreadable file must not abort the whole load.
                try:
                    cache_file_name = self._get_cache_file_name(file_name)
                except TrackLoadError:
                    continue
                future = executor.submit(load_cached_track_file,
                                         cache_file_name, file_name)
                future_to_file_name[future] = file_name
            for future in concurrent.futures.as_completed(future_to_file_name):
                file_name = future_to_file_name[future]
                try:
                    t = future.result()
                except Exception:
                    # silently ignore failed cache load attempts
                    pass
                else:
                    tracks[file_name] = t

        return tracks

    def _store_tracks_to_cache(self, tracks: typing.Dict[str, Track]) -> None:
        if (not tracks) or (not self.cache_dir):
            return

        log.info("Storing %d track(s) to cache...", len(tracks))
        for (file_name, t) in tracks.items():
            try:
                t.store_cache(self._get_cache_file_name(file_name))
            except Exception as e:
                log.error("Failed to store track %s to cache: %s", file_name,
                          str(e))
            else:
                log.info("Stored track %s to cache", file_name)

    def _store_strava_tracks_to_cache(self,
                                      tracks: typing.List[Track]) -> None:
        if (not tracks) or (not self.cache_dir):
            return
        dirname = os.path.dirname(self.strava_cache_file)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        log.info("Storing %d track(s) to cache...", len(tracks))
        to_cache_tracks = [
            self._make_strava_cache_dict(track) for track in tracks
        ]
        with open(self.strava_cache_file, "w") as f:
            json.dump(to_cache_tracks, f)

    @staticmethod
    def _make_strava_cache_dict(track: Track) -> typing.Dict[str, Any]:
        """Serialise a Strava track into the JSON-friendly cache format."""
        lines_data = []
        for line in track.polylines:
            lines_data.append([{
                "lat": latlng.lat().degrees,
                "lng": latlng.lng().degrees
            } for latlng in line])
        return {
            "name": track.file_names[0],  # strava id
            "start": track.start_time().strftime("%Y-%m-%d %H:%M:%S"),
            "end": track.end_time().strftime("%Y-%m-%d %H:%M:%S"),
            "length": track.length_meters,
            "segments": lines_data,
        }

    @staticmethod
    def _strava_cache_to_track(data: typing.Dict[str, Any]) -> "Track":
        """Rebuild a Track from a dict produced by _make_strava_cache_dict."""
        t = Track()
        t.file_names = [data["name"]]
        t.set_start_time(
            datetime.datetime.strptime(data["start"], "%Y-%m-%d %H:%M:%S"))
        t.set_end_time(
            datetime.datetime.strptime(data["end"], "%Y-%m-%d %H:%M:%S"))
        t.length_meters = float(data["length"])
        t.polylines = []
        for data_line in data["segments"]:
            t.polylines.append([
                s2sphere.LatLng.from_degrees(float(d["lat"]), float(d["lng"]))
                for d in data_line
            ])
        return t

    @staticmethod
    def _list_gpx_files(base_dir: str) -> typing.Generator[str, None, None]:
        base_dir = os.path.abspath(base_dir)
        if not os.path.isdir(base_dir):
            raise ParameterError(f"Not a directory: {base_dir}")
        for name in os.listdir(base_dir):
            path_name = os.path.join(base_dir, name)
            if name.endswith(".gpx") and os.path.isfile(path_name):
                yield path_name

    def _get_cache_file_name(self, file_name: str) -> str:
        """Return ``cache_dir/<sha256 of file contents>.json`` (memoised).

        Raises:
            TrackLoadError: if the file cannot be read.
        """
        assert self.cache_dir

        if file_name in self._cache_file_names:
            return self._cache_file_names[file_name]

        try:
            digest = hashlib.sha256()
            # Close the file deterministically and hash in chunks.
            with open(file_name, "rb") as f:
                for chunk in iter(lambda: f.read(65536), b""):
                    digest.update(chunk)
            checksum = digest.hexdigest()
        except PermissionError as e:
            raise TrackLoadError(
                "Failed to compute checksum (bad permissions).") from e
        except Exception as e:
            raise TrackLoadError("Failed to compute checksum.") from e

        cache_file_name = os.path.join(self.cache_dir, f"{checksum}.json")
        self._cache_file_names[file_name] = cache_file_name
        return cache_file_name
 def __init__(self):
     """Initialise loader defaults; no cache directory is configured."""
     # Minimum track length default (unit not visible here; presumably meters).
     self.min_length = 1000
     self.special_file_names = list()
     self.year_range = YearRange()
     self.cache_dir = None
     self._cache_file_names = dict()
class TrackLoader:
    """Handle the loading of tracks from cache and/or GPX files

    Attributes:
        min_length: All tracks shorter than this value are filtered out.
            It is compared against ``Track.length`` (presumably meters;
            confirm against Track).
        special_file_names: Tracks marked as special in command line args
        year_range: All tracks outside of this range will be filtered out.
        cache_dir: Directory used to store cached tracks; a falsy value
            disables the cache.

    Methods:
        clear_cache: Remove cache directory
        load_tracks: Load all data from cache and GPX files
    """

    def __init__(self):
        self.min_length = 1000
        self.special_file_names = []
        self.year_range = YearRange()
        self.cache_dir = None
        # Memo: GPX file name -> cache file path, so each file is hashed once.
        self._cache_file_names = {}

    def clear_cache(self):
        """Remove the cache directory, if one is configured and it exists.

        A failure to remove the directory is logged, not raised.
        """
        # Return early when no cache dir is set: os.path.isdir(None) raises
        # TypeError rather than returning False.
        if not self.cache_dir or not os.path.isdir(self.cache_dir):
            return
        log.info(f"Removing cache dir: {self.cache_dir}")
        try:
            shutil.rmtree(self.cache_dir)
        except OSError as e:
            log.error(f"Failed: {e}")

    def load_tracks(self, base_dir: str) -> List["Track"]:
        """Load the GPX tracks found directly in ``base_dir``.

        When ``cache_dir`` is set, tracks are first taken from the cache. The
        remaining files are parsed from GPX and then stored to the cache. The
        combined tracks are filtered (see ``_filter_tracks``) and merged (see
        ``_merge_tracks``). Finally, tracks shorter than ``min_length`` are
        dropped.

        Args:
            base_dir: Directory scanned (non-recursively) for ``*.gpx`` files.

        Returns:
            The loaded, filtered and merged tracks.

        Raises:
            ParameterError: If ``base_dir`` is not a directory.
        """
        file_names = list(self._list_gpx_files(base_dir))
        log.info(f"GPX files: {len(file_names)}")

        tracks = []  # type: List[Track]

        # load track from cache
        cached_tracks = {}  # type: Dict[str, Track]
        if self.cache_dir:
            log.info(
                f"Trying to load {len(file_names)} track(s) from cache...")
            cached_tracks = self._load_tracks_from_cache(file_names)
            log.info(f"Loaded tracks from cache: {len(cached_tracks)}")
            tracks = list(cached_tracks.values())

        # load remaining gpx files
        remaining_file_names = [
            f for f in file_names if f not in cached_tracks
        ]
        if remaining_file_names:
            log.info(
                f"Trying to load {len(remaining_file_names)} track(s) from GPX files; this may take a while..."
            )
            loaded_tracks = self._load_tracks(remaining_file_names)
            tracks.extend(loaded_tracks.values())
            log.info(f"Conventionally loaded tracks: {len(loaded_tracks)}")
            self._store_tracks_to_cache(loaded_tracks)

        tracks = self._filter_tracks(tracks)

        # merge tracks that took place within one hour
        tracks = self._merge_tracks(tracks)
        # filter out tracks with length < min_length
        return [t for t in tracks if t.length >= self.min_length]

    def _filter_tracks(self, tracks: List["Track"]) -> List["Track"]:
        """Drop unusable tracks and flag special ones.

        A track is skipped (and logged) when its length is 0, when it has no
        start time, or when its start time is outside ``year_range``. Each
        kept track's ``special`` flag is set according to whether its first
        file name is in ``special_file_names``.
        """
        filtered_tracks = []
        for t in tracks:
            file_name = t.file_names[0]
            if t.length == 0:
                log.info(f"{file_name}: skipping empty track")
            elif not t.start_time:
                log.info(f"{file_name}: skipping track without start time")
            elif not self.year_range.contains(t.start_time):
                log.info(
                    f"{file_name}: skipping track with wrong year {t.start_time.year}"
                )
            else:
                t.special = file_name in self.special_file_names
                filtered_tracks.append(t)
        return filtered_tracks

    @staticmethod
    def _merge_tracks(tracks: List["Track"]) -> List["Track"]:
        """Sort tracks by start time and merge close consecutive tracks.

        A track is appended to the previously emitted track when it starts
        strictly between 0 and 3600 seconds after the end of the track that
        precedes it in start-time order.
        """
        log.info("Merging tracks...")
        tracks = sorted(tracks, key=lambda t1: t1.start_time)
        merged_tracks = []
        last_end_time = None
        for t in tracks:
            if last_end_time is None:
                merged_tracks.append(t)
            else:
                dt = (t.start_time - last_end_time).total_seconds()
                if 0 < dt < 3600:
                    merged_tracks[-1].append(t)
                else:
                    merged_tracks.append(t)
            last_end_time = t.end_time
        log.info(f"Merged {len(tracks) - len(merged_tracks)} track(s)")
        return merged_tracks

    @staticmethod
    def _load_tracks(file_names: List[str]) -> Dict[str, "Track"]:
        """Parse GPX files in worker processes via ``load_gpx_file``.

        Files whose loading raises TrackLoadError are logged and omitted.

        Returns:
            Mapping from GPX file name to the loaded track.
        """
        tracks = {}
        with concurrent.futures.ProcessPoolExecutor() as executor:
            future_to_file_name = {
                executor.submit(load_gpx_file, file_name): file_name
                for file_name in file_names
            }
        for future in concurrent.futures.as_completed(future_to_file_name):
            file_name = future_to_file_name[future]
            try:
                t = future.result()
            except TrackLoadError as e:
                log.error(f"Error while loading {file_name}: {e}")
            else:
                tracks[file_name] = t

        return tracks

    def _load_tracks_from_cache(self,
                                file_names: List[str]) -> Dict[str, "Track"]:
        """Load cached copies of ``file_names`` in worker processes.

        Files whose checksum cannot be computed, or whose cache entry cannot
        be loaded, are left out of the result. The caller then falls back to
        parsing their GPX files.

        Returns:
            Mapping from GPX file name to the cached track.
        """
        # Compute cache paths up front so a single unreadable file does not
        # abort the whole cache lookup. Otherwise its TrackLoadError would
        # propagate out of the submit comprehension.
        cache_file_names = {}
        for file_name in file_names:
            try:
                cache_file_names[file_name] = self._get_cache_file_name(
                    file_name)
            except TrackLoadError as e:
                log.info(f"{file_name}: cannot look up cache: {e}")
        if not cache_file_names:
            return {}

        tracks = {}
        with concurrent.futures.ProcessPoolExecutor() as executor:
            future_to_file_name = {
                executor.submit(
                    load_cached_track_file,
                    cache_file_name,
                    file_name,
                ): file_name
                for file_name, cache_file_name in cache_file_names.items()
            }
        for future in concurrent.futures.as_completed(future_to_file_name):
            file_name = future_to_file_name[future]
            try:
                t = future.result()
            except Exception:
                # silently ignore failed cache load attempts
                pass
            else:
                tracks[file_name] = t
        return tracks

    def _store_tracks_to_cache(self, tracks: Dict[str, "Track"]):
        """Write each track to its checksum-named cache file.

        Does nothing when ``tracks`` is empty or no ``cache_dir`` is set.
        Per-file failures are logged and skipped.
        """
        if (not tracks) or (not self.cache_dir):
            return

        log.info(f"Storing {len(tracks)} track(s) to cache...")
        for (file_name, t) in tracks.items():
            try:
                t.store_cache(self._get_cache_file_name(file_name))
            except Exception as e:
                log.error(f"Failed to store track {file_name} to cache: {e}")
            else:
                log.info(f"Stored track {file_name} to cache")

    @staticmethod
    def _list_gpx_files(base_dir: str) -> Generator[str, None, None]:
        """Yield absolute paths of the regular ``*.gpx`` files in ``base_dir``.

        Only the top level is scanned.

        Raises:
            ParameterError: If ``base_dir`` is not a directory. The error is
                raised on first iteration, since this is a generator.
        """
        base_dir = os.path.abspath(base_dir)
        if not os.path.isdir(base_dir):
            raise ParameterError(f"Not a directory: {base_dir}")
        for name in os.listdir(base_dir):
            path_name = os.path.join(base_dir, name)
            if name.endswith(".gpx") and os.path.isfile(path_name):
                yield path_name

    def _get_cache_file_name(self, file_name: str) -> str:
        """Return the cache file path for ``file_name``.

        The path is ``<cache_dir>/<sha256 of the file contents>.json``. The
        result is memoized per file name, so later changes to the file's
        contents are not picked up by this instance.

        Raises:
            TrackLoadError: If the file cannot be read to compute its checksum.
        """
        assert self.cache_dir

        if file_name in self._cache_file_names:
            return self._cache_file_names[file_name]

        digest = hashlib.sha256()
        try:
            # The context manager closes the handle promptly (the bare
            # open().read() leaked it). Hashing in chunks avoids loading
            # large GPX files into memory at once.
            with open(file_name, "rb") as f:
                for chunk in iter(lambda: f.read(65536), b""):
                    digest.update(chunk)
        except PermissionError as e:
            raise TrackLoadError(
                "Failed to compute checksum (bad permissions).") from e
        except Exception as e:
            raise TrackLoadError("Failed to compute checksum.") from e
        checksum = digest.hexdigest()

        cache_file_name = os.path.join(self.cache_dir, f"{checksum}.json")
        self._cache_file_names[file_name] = cache_file_name
        return cache_file_name