コード例 #1
0
ファイル: main.py プロジェクト: davidarnarsson/hello-python
class Sequencer:
    """Keeps notes in an interval tree and mixes them into a single sample.

    Every note is stored under the span running from ``note.start`` to
    ``note.start + note.length``, with the note itself as the interval data.
    """

    def sortkey(n):
        # Plain function kept on the class (it takes no ``self``); yields the
        # point at which note ``n`` ends.
        return n.start + n.length

    def __init__(self):
        self.notes = IntervalTree()

    def add(self, note):
        """Insert ``note`` under the span it occupies."""
        finish = note.start + note.length
        self.notes.addi(note.start, finish, note)

    def remove(self, note):
        """Delete the entry that ``add`` created for ``note``."""
        finish = note.start + note.length
        self.notes.removei(note.start, finish, note)

    def length(self):
        """Return whatever the underlying tree reports as its end."""
        return self.notes.end()

    def sample_at(self, t):
        """Return the mixed sample value at position ``t``.

        Each note active at ``t`` contributes its sine oscillator value times
        its velocity times its ADSR envelope, scaled by one over the number
        of active notes. With nothing active the result is ``0``.
        """
        active = self.notes.at(t)
        if not active:
            return 0

        # Every active note gets an equal share of the output.
        share = 1 / len(active)
        total = 0
        for interval in active:
            offset = t - interval.begin
            wave = osc.sine(offset, interval.data.pitch)
            loudness = interval.data.velocity
            envelope = adsr(offset, interval.end - interval.begin)
            total += (wave * loudness * envelope) * share
        return total
コード例 #2
0
ファイル: pointer.py プロジェクト: octorock/the-little-hat
class PointerList:
    """Pointers of a single ROM variant, indexed by address in an interval tree.

    Each pointer is stored under the span ``[address, address + 4)`` with the
    pointer object as the interval's data.
    """

    # Width of the span a pointer covers (presumably the pointer size in
    # bytes -- confirm against the ROM format).
    _POINTER_SPAN = 4

    def __init__(self, pointers: List[Pointer], rom_variant: RomVariant) -> None:
        """Index every pointer in ``pointers`` that belongs to ``rom_variant``."""
        self.tree = IntervalTree([
            self._interval_for(pointer)
            for pointer in pointers
            if pointer.rom_variant == rom_variant
        ])

    @classmethod
    def _interval_for(cls, pointer: Pointer) -> Interval:
        """Build the interval under which ``pointer`` is stored in the tree."""
        return Interval(pointer.address,
                        pointer.address + cls._POINTER_SPAN, pointer)

    def get_pointers_at(self, index: int) -> List[Pointer]:
        """Return the pointers whose span covers ``index``."""
        return [interval.data for interval in self.tree.at(index)]

    def append(self, pointer: Pointer) -> None:
        """Add ``pointer`` to the index."""
        self.tree.add(self._interval_for(pointer))

    def remove(self, pointer: Pointer) -> None:
        """Remove ``pointer`` from the index."""
        self.tree.remove(self._interval_for(pointer))

    def __iter__(self):
        """Iterate over the stored pointers in the tree's iteration order."""
        return (interval.data for interval in self.tree)

    def get_sorted_pointers(self) -> List[Pointer]:
        """Return the stored pointers as a list, ordered by their intervals.

        A real list is returned (matching the annotation) so the result can
        be indexed, measured and iterated more than once; previously this was
        a single-use ``map`` iterator.
        """
        return [interval.data for interval in sorted(self.tree)]
コード例 #3
0
 def assign_breakpoint_to_genes(self, bp: tuple):
     """Store a breakpoint on every gene whose exon span contains it.

     ``bp`` is indexed as ``(chromosome, position)``. For each of
     ``self.prime5`` and ``self.prime3`` located on that chromosome, the
     position is written to ``gene.breakpoint`` when it lies in the half-open
     range ``[gene.exons.begin(), gene.exons.end())``.

     The range is compared directly instead of building a one-interval
     ``IntervalTree`` per gene; the containment rule (lower bound inclusive,
     upper bound exclusive) is the same, and a gene whose exon span is empty
     (``begin() == end()``) is simply skipped rather than raising.
     """
     chromosome = bp[0]
     for gene in (self.prime5, self.prime3):
         if chromosome != gene.chromosome:
             continue
         position = bp[1]
         if gene.exons.begin() <= position < gene.exons.end():
             gene.breakpoint = position
コード例 #4
0
 def assign_breakpoint(self, bp, chrom, max_distance=20000):
     """Assign breakpoint ``bp`` to ``self.prime5`` / ``self.prime3``.

     First pass: the genes are checked in order (``prime5`` then ``prime3``)
     and the breakpoint goes to the first one whose exon span
     ``[exons.begin(), exons.end())`` contains ``bp``; a message is printed
     for every gene examined.

     Second pass: any gene that still has ``breakpoint == 0``, lies on
     ``chrom`` and whose exon span starts or ends less than ``max_distance``
     away from ``bp`` also receives the breakpoint. This covers breakpoints
     that fall just outside both genes.

     Args:
         bp: breakpoint position.
         chrom: chromosome the breakpoint is on (only used by the second pass).
         max_distance: proximity threshold for the second pass; defaults to
             the previously hard-coded 20000.
     """
     genes = (self.prime5, self.prime3)

     for gene in genes:
         # Direct half-open range test; equivalent to querying a throwaway
         # single-interval tree but without allocating one per gene.
         if gene.exons.begin() <= bp < gene.exons.end():
             gene.breakpoint = bp
             print("Breakpoint", bp, "assigned to", gene.gene_name)
             break
         print("Breakpoint", bp, "is outside gene", gene.gene_name)

     # there can be a chance that the breakpoint is outside of both genes,
     # assign it to the closest one
     for gene in genes:
         if gene.breakpoint != 0 or gene.chromosome != chrom:
             continue
         near_start = abs(gene.exons.begin() - bp) < max_distance
         if near_start or abs(gene.exons.end() - bp) < max_distance:
             gene.breakpoint = bp
             print("Breakpoint", bp, "assigned to", gene.gene_name)
コード例 #5
0
class AnnotationList:
    """Annotations of one ROM variant, searchable by position.

    An annotation is stored in ``self.tree`` under the span starting at
    ``annotation.address`` and extending ``annotation.length`` further, with
    the annotation object attached as the interval's data.
    """

    def __init__(self, annotations: List[Annotation],
                 rom_variant: RomVariant) -> None:
        """Index the entries of ``annotations`` that match ``rom_variant``."""
        spans = [
            Interval(entry.address, entry.address + entry.length, entry)
            for entry in annotations
            if entry.rom_variant == rom_variant
        ]
        self.tree = IntervalTree(spans)

    def get_annotations_at(self, index: int) -> List[Annotation]:
        """Return the annotations whose span covers ``index``."""
        return [hit.data for hit in self.tree.at(index)]
コード例 #6
0
                    continue
            cds_all = IntervalTree()
            cds_gene = IntervalTree()
            histogram_all = Counter()
            histogram_gene = dict()
            for gene_name, gene_info in genes.items():
                gene_info['cds'].merge_overlaps()
                histogram_gene[gene_info['id']] = Counter()
                for interval in gene_info['cds']:
                    cds_all.addi(interval.begin, interval.end)
                    cds_gene.addi(interval.begin, interval.end, gene_info['id'])
            cds_all.merge_overlaps()
            for interval in cds_all:
                records = list(ifile.fetch(chrom, interval.begin - 1, interval.end - 1))
                if not records: # if nothing fetched, try adding or removing 'chr' prefix
                    if chrom.startswith('chr'):
                        records = ifile.fetch(chrom[3:], interval.begin - 1, interval.end - 1)
                    else:
                        records = ifile.fetch('chr' + chrom, interval.begin - 1, interval.end - 1)
                for record in records:
                    overlapping_genes = [x.data for x in cds_gene.at(record.pos)]
                    dps = [fmt['DP'] for _, fmt in record.samples.items()]
                    histogram_all.update(dps)
                    for overlapping_gene in overlapping_genes:
                        histogram_gene[overlapping_gene].update(dps)
            if len(histogram_all) > 0:
                jsons.append({'chrom': chrom, 'type': 'all', 'histogram': histogram_all})
                for gene_id, gene_histogram in histogram_gene.items():
                    jsons.append({'chrom': chrom, 'type': 'gene', 'gene_id': gene_id, 'histogram': gene_histogram})
        json.dump(jsons, ofile)
コード例 #7
0
ファイル: time.py プロジェクト: gnoose/datafeeds-shared
class DateIntervalTree:
    """An ``IntervalTree`` wrapper whose public API works with ``datetime.date``.

    Dates are stored in the underlying tree as their proleptic ordinals
    (``date.toordinal()``), which keeps ordering intact, and are converted
    back to ``date`` objects whenever intervals are handed to the caller.

    Unless a method says otherwise, intervals include their lower bound and
    exclude their upper bound: an interval from A to B contains A but not B.
    """

    def __init__(self):
        self.tree = IntervalTree()

    @staticmethod
    def to_date_interval(begin: date, end: date, data: Any) -> Interval:
        """Build the ordinal interval for ``begin``/``end``, carrying ``data``."""
        lower = begin.toordinal()
        upper = end.toordinal()
        return Interval(lower, upper, data)

    @staticmethod
    def from_date_interval(ival: Interval) -> Interval:
        """Turn an ordinal interval back into one whose bounds are dates."""
        first_day = date.fromordinal(ival.begin)
        last_day = date.fromordinal(ival.end)
        return Interval(first_day, last_day, ival.data)

    def add(self, begin: date, end: date, data: Any = None):
        """Insert the date interval ``begin``..``end`` with optional ``data``."""
        ordinal_interval = DateIntervalTree.to_date_interval(begin, end, data)
        self.tree.add(ordinal_interval)

    def merge_overlaps(self, reducer: Callable = None, strict: bool = True):
        """Merge overlapping date intervals in place.

        ``reducer`` is forwarded as the underlying tree's ``data_reducer`` and
        decides how the data of merged intervals is combined. ``strict`` is
        forwarded unchanged: when true (the default) only intervals that
        really overlap are merged; otherwise intervals that merely touch are
        merged as well.

        See the intervaltree documentation of ``merge_overlaps`` for details.
        """
        self.tree.merge_overlaps(data_reducer=reducer, strict=strict)

    def intervals(self) -> List[Interval]:
        """Return every stored interval, with bounds converted back to dates."""
        return list(
            map(DateIntervalTree.from_date_interval, self.tree.items()))

    def overlaps(self, begin: date, end: date, strict: bool = True) -> bool:
        """Tell whether ``begin``..``end`` overlaps anything in the tree.

        Because upper bounds are exclusive, 2015-07-23..2015-08-21 does not
        overlap 2015-08-21..2015-09-21. With ``strict=False`` the queried end
        date is pushed one day later, so such single-day contact counts as an
        overlap.
        """
        upper = end if strict else end + timedelta(days=1)
        query = DateIntervalTree.to_date_interval(begin, upper, None)
        return self.tree.overlaps(query.begin, query.end)

    def range_query(self, begin: date, end: date) -> List[Interval]:
        """Return the stored intervals that overlap ``begin``..``end``."""
        query = DateIntervalTree.to_date_interval(begin, end, None)
        matches = self.tree.overlap(query.begin, query.end)
        return [DateIntervalTree.from_date_interval(hit) for hit in matches]

    def point_query(self, point: date) -> List[Interval]:
        """Return the stored intervals that contain the date ``point``."""
        matches = self.tree.at(point.toordinal())
        return [DateIntervalTree.from_date_interval(hit) for hit in matches]

    @staticmethod
    def shift_endpoints(date_tree: "DateIntervalTree") -> "DateIntervalTree":
        """Return a new tree in which touching intervals no longer share a date.

        Intervals are visited earliest to latest. Whenever an interval ends on
        the very date the next one begins, its end is moved one day earlier:

            (1/1/2000, 1/10/2000), (1/10/2000, 1/20/2000)
        becomes
            (1/1/2000, 1/9/2000), (1/10/2000, 1/20/2000)
        """
        result = DateIntervalTree()
        ordered = sorted(date_tree.intervals())
        last_index = len(ordered) - 1
        for position, current in enumerate(ordered):
            end = current.end
            # Compare against the *original* successor, not an adjusted one.
            if position < last_index and end == ordered[position + 1].begin:
                end = end - timedelta(days=1)
            result.add(current.begin, end, current.data)
        return result

    @staticmethod
    def shift_endpoints_start(
            date_tree: "DateIntervalTree") -> "DateIntervalTree":
        """Return a new tree in which touching intervals no longer share a date.

        Intervals are visited latest to earliest. Whenever an interval begins
        on the very date the preceding one ends, its start is moved one day
        later:

            (1/1/2000, 1/10/2000), (1/10/2000, 1/20/2000)
        becomes
            (1/1/2000, 1/10/2000), (1/11/2000, 1/20/2000)
        """
        result = DateIntervalTree()
        ordered = sorted(date_tree.intervals(), reverse=True)
        last_index = len(ordered) - 1
        for position, current in enumerate(ordered):
            begin = current.begin
            # In this descending order the following element is the interval
            # that comes earlier in time.
            if position < last_index and begin == ordered[position + 1].end:
                log.debug(
                    "adjusting start of billing period: %s-%s",
                    current.begin,
                    current.end,
                )
                begin = begin + timedelta(days=1)
            result.add(begin, current.end, current.data)
        return result

    @staticmethod
    def shift_endpoints_end(
            date_tree: "DateIntervalTree") -> "DateIntervalTree":
        """Return a new tree in which no interval reaches into the later one.

        Intervals are visited latest to earliest. Each interval's end is moved
        back one day at a time until it lies strictly before the start of the
        (already adjusted) interval handled just before it:

            (1/1/2000, 1/10/2000), (1/10/2000, 1/20/2000)
        becomes
            (1/1/2000, 1/9/2000), (1/10/2000, 1/20/2000)

        If moving the end makes it equal to the start, the start is moved back
        a day as well so the interval does not become empty.
        """
        one_day = timedelta(days=1)
        result = DateIntervalTree()
        later = None
        for current in sorted(date_tree.intervals(), reverse=True):
            if later:
                while current.end >= later.begin:
                    shrunk_end = current.end - one_day
                    shrunk_begin = current.begin
                    if shrunk_begin == shrunk_end:
                        # Would collapse to nothing: pull the start back too.
                        shrunk_begin = shrunk_begin - one_day
                    current = Interval(shrunk_begin, shrunk_end, current.data)
            later = current
            result.add(current.begin, current.end, current.data)
        return result
コード例 #8
0
class QueryCompleter(Completer):
    """
    Suggests JMESPath query syntax completions.

    After receiving AWS service and operation names in form
    of awscli command and subcommand, an output shape loaded
    from botocore Session is parsed by the ShapeParser object.
    This object returns a "Dummy response", which is used in
    attempt to provide sensible suggestions.

    At the moment, this completer is unable to provide suggestions
    for JMESPath functions and custom hashes and arrays.

    State overview (all reset by ``reset()`` unless noted):

    * ``context`` -- the part of the dummy response the cursor is
      currently "inside"; handlers narrow it token by token.
    * ``_implicit_context`` -- the context to fall back to on tokens in
      ``CONTEXT_RESET_SIGNS``.
    * ``_stack`` -- ``(token_type, index)`` pairs for enclosures that are
      still open; ``index`` is a position usable to look up ``_tree``.
    * ``_tree`` -- interval tree mapping spans of the query text to the
      implicit context stored for them.
    * ``_disable`` -- ``(flag, position)``; while the flag is set and the
      cursor is past ``position`` no completions are produced.
    """
    def __init__(self, session, **kwds):
        """Create the completer.

        A fresh ``Session`` is built from ``session.profile_name``; any
        extra keyword arguments are passed to the base class.
        """
        self._session = Session(profile=session.profile_name)
        self._command_table = None
        self._shape_parser = ShapeParser()
        self._lexer = jmespath.lexer.Lexer()
        self._service = None
        self._operation = None
        # Set by _parse_shape() the first time a service model is loaded.
        self._service_model = None
        # Attributes below change as the query changes.
        # They are used to to track state to provide suggestions.
        self._should_reparse = True
        self._shape_dict = None
        self._context = None
        self._last_pos = 0
        self._implicit_context = []
        self._stack = []
        self._tree = IntervalTree()
        self._start = 0
        self._colon = False
        self._disable = (False, 0)
        # Initialised here so handlers never hit an AttributeError when they
        # are reached before get_completions()/_parse_completion() ran.
        self._repeat_suggestion = False
        self._text = ' '
        self._tokens = []
        super(QueryCompleter, self).__init__(**kwds)

    @property
    def context(self):
        """
        Get the context attribute.

        This is used to track the state of mutating fake API response.
        """
        if self._context is None:
            self.context = self._shape_dict
        return self._context

    @context.setter
    def context(self, value):
        """Set the value of context attribute."""
        self._context = value

    @property
    def command_table(self):
        """
        Get the command table attribute.

        This is used to transform aws-cli command and subcommand
        into their API operation counterpart.
        """
        if self._command_table is None:
            self._command_table = build_command_table(self._session)
        return self._command_table

    def set_shape_dict(self, service, operation):
        """
        Set the fake response (shape dict).

        This is based on received aws-cli service and operation
        (command, subcommand).
        """
        shape_dict = self._get_shape_dict(service, operation)
        self._shape_dict = shape_dict
        self.context = shape_dict

    def reset(self):
        """Reset the state of the completer."""
        self.context = None
        self._implicit_context = list()
        self._stack = list()
        self._tree = IntervalTree()
        self._start = 0
        self._colon = False
        self._disable = (False, 0)
        self._repeat_suggestion = False

    def get_completions(self, document, c_e):
        """
        Retrieve suggestions for the JMESPath query.

        First parse the existing part of the query with the JMESPath
        lexer. Based on the last token type, choose appropriate
        handler method. This handler method then returns a list of
        suggested completions, which are then yielded from here.

        As the query is being parsed, the Completer
        tracks the state of the query. If the query is being
        corrected, deleted or a larger chunk is pasted at once,
        the Completer has to reparse the query (rebuild the state).
        """
        if not self._shape_dict:
            return
        if self._disable[0]:
            # Stay silent while the cursor is beyond the point where the
            # completer gave up; re-enable once it moves back.
            if document.cursor_position > self._disable[1]:
                return
            self._disable = (False, 0)

        should_repeat = not bool(document.get_word_before_cursor())
        self._repeat_suggestion = c_e.completion_requested or should_repeat
        completions = self._parse_completion(document, c_e)
        self._last_pos = document.cursor_position

        if not completions:
            return

        word = document.get_word_before_cursor(pattern=_FIND_IDENTIFIER)
        for c in sorted(completions):
            # Single-character suggestions are inserted at the cursor; longer
            # ones replace the identifier typed so far.
            start_position = 0 if len(c) == 1 else -len(word)
            yield Completion(text_type(c), start_position=start_position)

    def _parse_completion(self, document, c_e):
        """Tokenize the text before the cursor and compute completions.

        Returns ``None`` when the text cannot be tokenized. Decides between
        handling only the last token (cursor moved forward, or an explicit
        completion request at the same position) and rebuilding the whole
        state from the first token.
        """
        text = document.text_before_cursor
        self._text = ' ' if not text else text

        try:
            self._tokens = list(self._lexer.tokenize(self._text))
        except jmespath.exceptions.LexerError:
            return

        if self._tokens[-1]['type'] == 'eof':
            self._tokens.pop()

        if not self._tokens:
            return self.context.keys()

        if self._should_reparse:
            completions = self._reparse_completion()
            self._should_reparse = False
        elif document.cursor_position > self._last_pos:
            completions = self._append_completion()
        elif (document.cursor_position == self._last_pos
              and c_e.completion_requested):
            completions = self._append_completion()
        else:
            completions = self._reparse_completion()
            self._should_reparse = True
        return completions

    def _append_completion(self):
        """Handle only the last token on top of the already tracked state."""
        last_token = self._tokens[-1]
        index = len(self._tokens) - 1
        penultimate_token = self._look_back(index, 1)
        try:
            return self._handle_token(last_token, penultimate_token, index)
        except NullIntervalException as e:
            self._disable = True, e.pos
            return

    def _reparse_completion(self):
        """Reset the state and replay every token from the start.

        Returns the completions produced by the last token, or ``None`` if
        the completer got disabled along the way.
        """
        completions = list()
        self.reset()
        for i, token in enumerate(self._tokens):
            if self._disable[0]:
                return
            penultimate_token = self._look_back(i, 1)
            try:
                if token['type'] in COMPLEX_SIGNS:
                    # Replay a synthetic '[' first so the state matches what
                    # incremental typing would have produced. It lives inside
                    # the try block because it can raise
                    # NullIntervalException just like the real token.
                    fake_lbracket = {
                        'type': 'lbracket',
                        'start': token['start'],
                        'end': token['end'] - 1
                    }
                    self._handle_token(fake_lbracket, penultimate_token, i)
                completions = self._handle_token(token, penultimate_token, i)
            except NullIntervalException as e:
                self._disable = True, e.pos
                return
        return completions

    def _handle_token(self, token, prev_token, index=None):
        """Dispatch ``token`` to ``_handle_<type>`` (or ``_handle_others``).

        ``index`` is the token's position in ``self._tokens``; when omitted
        the last token's position is used.
        """
        # Compare against None explicitly: 0 is a valid index (first token)
        # and must not be replaced by the last token's index.
        if index is None:
            index = len(self._tokens) - 1
        handler = getattr(self, '_handle_%s' % token['type'],
                          self._handle_others)
        return handler(token, prev_token, index)

    def _handle_lbracket(self, token, prev_token, index):
        """Handle ``[``: open a new implicit context and suggest what follows."""
        if not prev_token:
            if isinstance(self.context, dict):
                return self.context.keys()
            return
        if not self._repeat_suggestion:
            self._switch_into_next_implicit_context(token)

        if (prev_token['type'] in IDENTIFIERS
                and isinstance(self.context, dict)):
            value = self.context.get(prev_token['value'], None)
            if isinstance(value, list):
                if not self._repeat_suggestion:
                    self.context = value
                return LBRACKETS_CONTINUATION
            # Indexing into something that is not a list: give up here.
            self._disable = (True, token['end'])
            return

        if prev_token['type'] == 'dot':
            if isinstance(self.context, dict):
                return self.context.keys()
            self._disable = (True, token['end'])
            return

        if isinstance(self.context, list):
            return LBRACKETS_CONTINUATION

    def _handle_filter(self, token, prev_token, index):
        """Handle a filter token: descend into the list's first element.

        The top stack entry is re-labelled with this token's type, the
        context becomes the first element of the current list, and that
        element is recorded as the new implicit context.
        """
        if self._repeat_suggestion:
            return self.context.keys()

        _, index = self._stack.pop()
        promise = token['type'], index
        self._stack.append(promise)
        if not isinstance(self.context, list):
            self._disable = (True, token['end'])
            return

        self.context = next(iter(self.context))
        if not isinstance(self.context, dict):
            self._disable = (True, token['end'])
            return

        self._implicit_context = copy.deepcopy(self.context)
        end = self._tree.end() - 1
        start = next(iter(self._tree.at(end))).begin
        self._tree[start:token['start']] = self._implicit_context
        return self.context.keys()

    def _handle_lbrace(self, token, _, index):
        """Handle ``{``: only valid on a dict context; opens a new context."""
        if not isinstance(self.context, dict):
            self._disable = (True, token['end'])
            return
        if not self._repeat_suggestion:
            self._switch_into_next_implicit_context(token)

    def _handle_colon(self, token, prev_token, index):
        """Handle ``:`` inside an open bracket or brace."""
        if not self._stack:
            self._disable = True, token['end']
            return
        if self._stack[-1][0] == 'lbracket':
            return
        if self._stack[-1][0] == 'lbrace':
            # Only the first colon after an identifier yields suggestions;
            # _colon is cleared again on the next comma.
            if not self._colon and prev_token['type'] in IDENTIFIERS:
                self._colon = True
                return self.context.keys()

    def _handle_flatten(self, token, _, index):
        """Handle a flatten token: flatten the context and close the bracket."""
        if self._repeat_suggestion:
            return
        self.context = JMESPATH_FLATTEN.search(self.context)
        old_end = token['end']
        self._tree[self._start:old_end] = self._implicit_context
        _, returning_context_index = self._stack.pop()
        context_interval = next(iter(self._tree[returning_context_index]))
        self._implicit_context = context_interval.data
        self._start = old_end

    def _handle_rbracket(self, token, prev_token, index):
        """Handle ``]``: close the enclosure and step into list elements."""
        if self._repeat_suggestion:
            return
        # Must be read before the stack entry is popped below.
        is_filter = (self._stack and self._stack[-1][0] == 'filter')
        self._switch_from_prev_implicit_context(token)
        # Handle [*] projection and index access (e.g.: lst[1])
        if prev_token and prev_token['type'] in {'star', 'number'}:
            # need antepenultimate (third to last) token info
            apu_token = self._look_back(index, 2)
            if apu_token and apu_token['type'] == 'lbracket':
                if isinstance(self.context, list):
                    self.context = next(iter(self.context))
            else:
                self._disable = (True, token['end'])
        elif prev_token and prev_token['type'] in STRINGS:
            if not is_filter:
                self._disable = (True, token['end'])

    def _handle_rbrace(self, token, _, index):
        """Handle ``}``: no suggestions are offered past this point."""
        self._disable = True, token['end']
        return

    def _handle_dot(self, token, prev_token, index):
        """Handle ``.``: move the context into the value selected before it."""
        if not prev_token:
            self._disable = (True, token['end'])
            return
        # Applying subexpression to a JSON object
        if isinstance(self.context, dict):
            if self._repeat_suggestion:
                return self.context.keys()
            # Simulate application of * projection
            if prev_token['type'] == 'star':
                new_context = list(self.context.values())
            # Receive the value of identifier
            elif prev_token['type'] in IDENTIFIERS:
                new_context = self.context.get(prev_token['value'], None)
            elif prev_token['type'] in {'rbracket', 'flatten'}:
                new_context = self.context
            # Nothing else is applicable to JSON objects
            else:
                new_context = dict()
            self.context = new_context
            if isinstance(self.context, dict):
                return self.context.keys()
        # Applying subexpression to a JSON list
        if isinstance(self.context, list):
            if prev_token['type'] == 'flatten':
                self.context = next(iter(self.context))
                if isinstance(self.context, dict):
                    return self.context.keys()
            if prev_token['type'] == 'rbracket':
                return LBRACKET
            self._disable = (True, token['end'])

    def _handle_pipe(self, token, _, index):
        """Handle ``|``: evaluate the left-hand side against the dummy data.

        The expression left of the pipe (with filters stripped) is run
        through ``jmespath.search`` and its result becomes the new context.
        """
        if not self._repeat_suggestion:
            if self._stack:
                # Inside an open enclosure: evaluate only the part of the
                # text after the enclosure start (and after the last
                # colon/comma) against the context stored for it.
                pos = self._stack[-1][1]
                context_interval = next(iter((self._tree[pos])))
                context = context_interval.data
                lhs = self._text[pos:token['start']]
                tokens = list(self._lexer.tokenize(lhs))
                for a_token in reversed(tokens):
                    if a_token['type'] in {'colon', 'comma'}:
                        lhs = lhs[a_token['end']:]
                        break
            else:
                lhs = self._text[:token['start']]
                context = self._shape_dict

            tokens = list(self._lexer.tokenize(lhs))
            lhs = self._remove_filters(lhs, tokens)
            try:
                result = jmespath.search(lhs, context)
            except jmespath.exceptions.JMESPathError:
                return
            self.context = result

        if isinstance(self.context, list):
            return LBRACKET
        if isinstance(self.context, dict):
            return self.context.keys()

    def _handle_others(self, token, _, index):
        """Fallback handler for token types without a dedicated method."""
        if token['type'] == 'comma':
            if self._stack and self._stack[-1][0] == 'lbrace':
                self._colon = False
                return
            if not self._stack:
                self._disable = (True, token['end'])
                return

        # Drop to fallback context on these... (&& || , > < etc...)
        if token['type'] in CONTEXT_RESET_SIGNS:
            if not self._repeat_suggestion:
                self.context = copy.deepcopy(self._implicit_context)
            if isinstance(self.context, dict):
                return self.context.keys()

        if token['type'] in IDENTIFIERS:
            if (self._stack and self._stack[-1][0] == 'lbrace'
                    and not self._colon):
                return
            identifier = token['value']
            if isinstance(self.context, dict):
                value = self.context.get(identifier, None)
                if isinstance(value, list):
                    return LBRACKET
                completions = [
                    c for c in self.context.keys() if c.startswith(identifier)
                ]
                return completions

    def _switch_into_next_implicit_context(self, token):
        """Record the current context for the span ending at ``token``.

        Raises ``NullIntervalException`` when the span would be empty.
        """
        old_end = token['end']
        if self._start == old_end:
            raise NullIntervalException(token['end'])
        self._implicit_context = copy.deepcopy(self.context)
        self._tree[self._start:old_end] = self._implicit_context
        self._start = old_end
        promise = token['type'], old_end - 1
        self._stack.append(promise)

    def _switch_from_prev_implicit_context(self, token):
        """Close the innermost enclosure and restore its implicit context.

        Disables the completer if ``token`` does not match the enclosure on
        top of the stack. Raises ``NullIntervalException`` when the span to
        record would be empty.
        """
        if (not self._stack
                or ENCLOSURE_MATCH[self._stack[-1][0]] != token['type']):
            self._disable = (True, token['end'])
            return
        old_end = token['end']
        if self._start == old_end:
            raise NullIntervalException(token['end'])
        self._tree[self._start:old_end] = self._implicit_context
        _, returning_context_index = self._stack.pop()
        # Indexing the tree yields a set of Interval objects; the context
        # itself is the interval's data (as in _handle_flatten/_handle_pipe).
        # Storing the raw set would make every later isinstance(dict/list)
        # check on the restored context fail.
        context_interval = next(
            iter(self._tree[returning_context_index]), None)
        if context_interval is not None:
            self._implicit_context = context_interval.data
        self._start = old_end

    def _look_back(self, index, offset):
        """Return the token ``offset`` places before ``index``, or ``None``."""
        if index < offset:
            return
        index = index - offset
        return self._tokens[index]

    def _remove_filters(self, expression, tokens):
        """Return ``expression`` with every filter span cut out.

        Spans are removed back to front so earlier offsets stay valid.
        """
        intervals = self._detect_filters(tokens)
        for interval in reversed(intervals):
            start, end = interval
            expression = expression[:start] + expression[end:]
        return expression

    def _detect_filters(self, tokens):
        """Return ``(start, end)`` offsets of the top-level filters in ``tokens``.

        ``counter`` tracks brackets/filters nested inside the current filter
        so that only the matching closing bracket ends the span.
        """
        in_filter_context = False
        counter = 0
        intervals = list()
        for token in tokens:
            if not in_filter_context and token['type'] == 'filter':
                in_filter_context = True
                start = token['start']
            elif in_filter_context:
                if token['type'] in {'filter', 'lbracket'}:
                    counter += 1
                elif token['type'] == 'rbracket':
                    if counter == 0:
                        in_filter_context = False
                        end = token['end']
                        intervals.append((start, end))
                    else:
                        counter -= 1
        return intervals

    def _get_shape_dict(self, service, operation):
        """Return the parsed output shape, or ``None`` if it cannot be loaded."""
        try:
            service, operation = (self._get_transformed_names(
                service, operation))
        except InvalidShapeData:
            return None

        try:
            return self._parse_shape(service, operation)
        except ModelLoadingError:
            return None

    def _get_transformed_names(self, service, operation):
        """Map CLI command/subcommand names to service/operation names.

        Raises ``InvalidShapeData`` when either name is unknown.
        """
        if service == 's3api':
            service = 's3'
        service_data = self.command_table.get(service, None)
        if not service_data:
            raise InvalidShapeData()
        operation = service_data.get_operation_name(operation)
        if not operation:
            raise InvalidShapeData()
        return service, operation

    def _parse_shape(self, service, operation):
        """Parse the operation's output shape, reusing cached results.

        The service model is reloaded only when the service changes, and the
        shape is re-parsed only when the service or operation changes;
        otherwise the current shape dict is returned as is.
        """
        if service != self._service:
            self._service_model = self._load_service_model(service)
            operation_model = (self._load_operation_model(
                self._service_model, operation))
            parsed = self._shape_parser.parse(operation_model.output_shape)
            self._service = service
            self._operation = operation
            return parsed

        if operation != self._operation:
            operation_model = (self._load_operation_model(
                self._service_model, operation))
            parsed = self._shape_parser.parse(operation_model.output_shape)
            self._operation = operation
            return parsed

        return self._shape_dict

    def _load_service_model(self, service_name):
        """Load a service model, raising ``ModelLoadingError`` if unknown."""
        try:
            service_model = self._session.get_service_model(service_name)
        except UnknownServiceError as e:
            raise ModelLoadingError(str(e))
        return service_model

    def _load_operation_model(self, service_model, operation):
        """Load an operation model, raising ``ModelLoadingError`` if unknown."""
        try:
            operation_model = service_model.operation_model(operation)
        except OperationNotFoundError as e:
            raise ModelLoadingError(str(e))
        return operation_model
コード例 #9
0
    def add_slides(self, with_annotations):
        """Add the slides layer and, optionally, cursor and annotation layers.

        Slide images are read from ``shapes.svg`` in ``self.opts.basedir`` and
        added as clips on a new 'Slides' layer.  When ``with_annotations`` is
        true, a 'Cursor' layer is populated from ``cursor.xml`` and an
        'Annotations' layer from the ``canvas`` groups of ``shapes.svg``.
        For annotations, one SVG file is written to the base directory for
        every time span obtained by splitting the shape intervals of a slide.

        Args:
            with_annotations: When false, return right after the slides
                layer has been populated.
        """
        layer = self._add_layer('Slides')
        doc = ET.parse(os.path.join(self.opts.basedir, 'shapes.svg'))
        slides = {}
        slide_time = IntervalTree()
        for img in doc.iterfind(
                './{http://www.w3.org/2000/svg}image[@class="slide"]'):
            info = SlideInfo(
                id=img.get('id'),
                width=int(img.get('width')),
                height=int(img.get('height')),
                start=round(float(img.get('in')) * Gst.SECOND),
                end=round(float(img.get('out')) * Gst.SECOND),
            )
            # Every slide is indexed (even ones that get no clip below) so
            # that cursor and annotation handling can look up its geometry.
            slides[info.id] = info
            slide_time.addi(info.start, info.end, info)

            # Don't bother creating an asset for out of range slides
            if info.end < self.start_time or info.start > self.end_time:
                continue

            path = img.get('{http://www.w3.org/1999/xlink}href')
            # If this is a "deskshare" slide, don't show anything
            if path.endswith('/deskshare.png'):
                continue

            asset = self._get_asset(os.path.join(self.opts.basedir, path))
            width, height = self._constrain(
                self._get_dimensions(asset),
                (self.slides_width, self.opts.height))
            self._add_clip(layer, asset, info.start, 0, info.end - info.start,
                           0, 0, width, height)

        # If we're not processing annotations, then we're done.
        if not with_annotations:
            return

        cursor_layer = self._add_layer('Cursor')
        # Move above the slides layer
        self.timeline.move_layer(cursor_layer, cursor_layer.get_priority() - 1)
        dot = self._get_asset('dot.png')
        dot_width, dot_height = self._get_dimensions(dot)
        cursor_doc = ET.parse(os.path.join(self.opts.basedir, 'cursor.xml'))
        events = []
        for event in cursor_doc.iterfind('./event'):
            x, y = event.find('./cursor').text.split()
            start = round(float(event.attrib['timestamp']) * Gst.SECOND)
            events.append(CursorEvent(float(x), float(y), start))

        for i, pos in enumerate(events):
            # negative positions are used to indicate that no cursor
            # should be displayed.
            if pos.x < 0 and pos.y < 0:
                continue

            # Show cursor until next event or if it is the last event,
            # the end of recording.
            if i + 1 < len(events):
                end = events[i + 1].start
            else:
                end = self.end_time

            # Find the width/height of the slide corresponding to this
            # point in time.
            # NOTE(review): this raises IndexError when no slide interval
            # contains pos.start -- confirm that cannot happen in practice.
            info = [hit.data for hit in slide_time.at(pos.start)][0]
            width, height = self._constrain(
                (info.width, info.height),
                (self.slides_width, self.opts.height))

            # The cursor position is scaled by the constrained slide size and
            # the dot is centred on it.
            self._add_clip(cursor_layer, dot, pos.start, 0, end - pos.start,
                           round(width * pos.x - dot_width / 2),
                           round(height * pos.y - dot_height / 2), dot_width,
                           dot_height)

        layer = self._add_layer('Annotations')
        # Move above the slides layer
        self.timeline.move_layer(layer, layer.get_priority() - 1)
        for canvas in doc.iterfind(
                './{http://www.w3.org/2000/svg}g[@class="canvas"]'):
            info = slides[canvas.get('image')]
            t = IntervalTree()
            for index, shape in enumerate(
                    canvas.iterfind(
                        './{http://www.w3.org/2000/svg}g[@class="shape"]')):
                shape.set('style',
                          shape.get('style').replace('visibility:hidden;', ''))
                timestamp = round(float(shape.get('timestamp')) * Gst.SECOND)
                undo = round(float(shape.get('undo')) * Gst.SECOND)
                # A negative undo time is replaced by the end of the slide.
                if undo < 0:
                    undo = info.end

                # Clip timestamps to slide visibility
                start = min(max(timestamp, info.start), info.end)
                end = min(max(undo, info.start), info.end)

                # Don't bother creating annotations for out of range times
                if end < self.start_time or start > self.end_time:
                    continue

                # Data is a one-element list so that merging intervals can
                # concatenate the (index, shape) pairs with operator.add.
                t.addi(start, end, [(index, shape)])

            t.split_overlaps()
            t.merge_overlaps(strict=True, data_reducer=operator.add)
            for interval_index, interval in enumerate(sorted(t)):
                svg = ET.Element('{http://www.w3.org/2000/svg}svg')
                svg.set('version', '1.1')
                svg.set('width', '{}px'.format(info.width))
                svg.set('height', '{}px'.format(info.height))
                svg.set('viewBox', '0 0 {} {}'.format(info.width, info.height))

                # We want to discard all but the last version of each
                # shape ID, which requires two passes.
                # The loop variables below are deliberately distinct from
                # interval_index: reusing one name made the file name depend
                # on the last shape's index, so different intervals of the
                # same slide could be written to the same path.
                shapes = sorted(interval.data)
                last_version = {}
                for shape_order, shape in shapes:
                    last_version[shape.get('shape')] = shape_order
                for shape_order, shape in shapes:
                    if last_version[shape.get('shape')] != shape_order:
                        continue
                    svg.append(shape)

                path = os.path.join(
                    self.opts.basedir,
                    'annotations-{}-{}.svg'.format(info.id, interval_index))
                with open(path, 'wb') as fp:
                    fp.write(ET.tostring(svg, xml_declaration=True))

                asset = self._get_asset(path)
                width, height = self._constrain(
                    (info.width, info.height),
                    (self.slides_width, self.opts.height))
                self._add_clip(layer, asset, interval.begin, 0,
                               interval.end - interval.begin, 0, 0, width,
                               height)
コード例 #10
0
class IntervalGroup(BaseTree):
    """Mapping-style wrapper around an ``IntervalTree``.

    Entries are addressed either by a ``(begin, end)`` tuple or by the
    equivalent ``"begin-end"`` string label.  Indexing additionally accepts
    a single integer position, an iterable of positions, or a slice.
    """

    _tree: IntervalTree

    @staticmethod
    def compatible_keys(keys):
        """Return True if every key in ``keys`` is a 2-tuple of ints."""
        for key in keys:
            if not isinstance(key, tuple):
                return False
            if len(key) != 2:
                return False
            if not all(isinstance(x, int) for x in key):
                return False
        return True

    @classmethod
    def from_dict(cls, d):
        """Build a group from a ``{(begin, end): value}`` mapping."""
        ivs = [Interval(*k, v) for k, v in d.items()]
        return cls(IntervalTree(ivs))

    @classmethod
    def from_label_dict(cls, d):
        """Build a group from a ``{"begin-end": value}`` mapping."""
        ivs = [Interval(*map(int, k.split("-")), v) for k, v in d.items()]
        return cls(IntervalTree(ivs))

    def add_group(self, name, group):
        """Store ``group`` under the interval key ``name``."""
        self[name] = group

    def key_to_label(self, key):
        """Format a ``(begin, end)`` key as a ``"begin-end"`` label."""
        return f"{key[0]}-{key[1]}"

    def label_to_key(self, label):
        """Parse a ``"begin-end"`` label into a tuple of ints."""
        # ``apply`` is not a builtin in Python 3; ``map`` performs the
        # per-part int conversion that was intended here.
        return tuple(map(int, label.split("-")))

    def to_label_dict(self):
        """Return ``{"begin-end": data}`` for all intervals, in sorted order."""
        return {f"{iv.begin}-{iv.end}": iv.data for iv in sorted(self._tree)}

    def to_dict(self):
        """Return ``{(begin, end): data}`` for all intervals, in sorted order."""
        return {(iv.begin, iv.end): iv.data for iv in sorted(self._tree)}

    def __init__(self, tree=None, *args, **kwargs):
        """Wrap ``tree``, or a new empty ``IntervalTree`` when omitted.

        Extra positional and keyword arguments are accepted and ignored.

        Raises:
            TypeError: If ``tree`` is not an ``IntervalTree``.
        """
        if tree is None:
            tree = IntervalTree()
        if not isinstance(tree, IntervalTree):
            raise TypeError("tree must be an instance of IntervalTree.")
        self._tree = tree

    def __getitem__(self, key):
        """Look up data by position, range, label, iterable or slice.

        * ``str``: parsed with ``label_to_key`` and handled as a tuple.
        * ``int``: see ``value``.
        * 2-tuple: see ``overlap_content``.
        * other iterable: see ``values_at``.
        * slice without step: see ``overlap``; missing bounds default to
          the tree's begin/end.
        * slice with step: ``values_at`` over the corresponding range.

        Any other key type yields ``None``.
        """
        if isinstance(key, str):
            key = self.label_to_key(key)
        if isinstance(key, int):
            return self.value(key)
        if isinstance(key, tuple) and len(key) == 2:
            return self.overlap_content(*key)
        if isinstance(key, Iterable):
            return self.values_at(key)
        if isinstance(key, slice):
            # Compare against None explicitly so that a bound of 0 is kept,
            # and use the defaulted bounds for both slice forms.
            start = self.start if key.start is None else key.start
            stop = self.end if key.stop is None else key.stop
            if key.step is None:
                return self.overlap(start, stop)
            return self.values_at(range(start, stop, key.step))
        return None

    @property
    def start(self):
        """Value of the wrapped tree's ``begin()``."""
        return self._tree.begin()

    @property
    def end(self):
        """Value of the wrapped tree's ``end()``."""
        return self._tree.end()

    def __setitem__(self, key, value):
        """Assign ``value`` to the interval(s) described by ``key``.

        ``key`` may be a ``"begin-end"`` label, a slice, or a tuple of
        ``(start, end)`` or ``(start, end, step)``.  Missing bounds default
        to the tree's begin/end.  With a step, ``value`` is iterated and its
        items are assigned to consecutive ``step``-wide intervals.

        Raises:
            ValueError: If a tuple key does not have 2 or 3 items.
            TypeError: If ``key`` is neither a label, slice nor tuple.
        """
        if isinstance(key, str):
            key = self.label_to_key(key)
        if isinstance(key, slice):
            start, stop, step = key.start, key.stop, key.step
        elif isinstance(key, tuple):
            if len(key) == 2:
                start, stop = key
                step = None
            elif len(key) == 3:
                start, stop, step = key
            else:
                raise ValueError(
                    "Setting intervals with tuple must be "
                    "of form (start, end) or (start, end, step)")
        else:
            raise TypeError(
                "Wrong type. Setting intervals can only be done using a "
                "slice or tuple of (start, end) or (start, end, step)")
        if start is None:
            start = self.start
        if stop is None:
            stop = self.end
        if step is None:
            self.set_interval(start, stop, value)
        else:
            indices = list(range(start, stop, step))
            for begin, end, val in zip(indices[:-1], indices[1:], value):
                self.set_interval(begin, end, val)

    def __delitem__(self, key):
        """Chop the range described by ``key`` out of the tree.

        ``key`` may be a ``"begin-end"`` label, a ``(begin, end)`` tuple or
        a slice (its step is ignored).  Missing bounds default to the
        tree's begin/end.

        Raises:
            TypeError: If ``key`` is none of the accepted forms.
        """
        if isinstance(key, str):
            key = self.label_to_key(key)
        if isinstance(key, tuple) and len(key) == 2:
            begin, end = key
        elif isinstance(key, slice):
            # Slices expose ``stop``; they have no ``end`` attribute.
            begin, end = key.start, key.stop
        else:
            raise TypeError("Must pass a tuple of (begin,end) or slice.")
        if begin is None:
            begin = self.start
        if end is None:
            end = self.end
        self._tree.chop(begin, end)

    def keys(self):
        """Yield ``(begin, end)`` for every interval, in sorted order."""
        for iv in sorted(self._tree):
            yield iv.begin, iv.end

    def labels(self):
        """Return an iterator of ``"begin-end"`` labels for all keys."""
        return map(self.key_to_label, self.keys())

    def items(self):
        """Yield ``((begin, end), data)`` pairs, in sorted order."""
        for iv in sorted(self._tree):
            yield (iv.begin, iv.end), iv.data

    def values(self):
        """Yield the data of every interval, in sorted order."""
        for iv in sorted(self._tree):
            yield iv.data

    def __iter__(self):
        return self.keys()

    def __len__(self):
        return len(self._tree)

    def __bool__(self):
        return bool(len(self._tree))

    def __getstate__(self):
        """Return the intervals as a sorted tuple of plain tuples."""
        return tuple(sorted([tuple(iv) for iv in self._tree]))

    def __setstate__(self, d):
        """Rebuild the tree from the tuples produced by ``__getstate__``."""
        ivs = [Interval(*iv) for iv in d]
        self._tree = IntervalTree(ivs)

    def overlap(self, begin, end):
        """Return overlapping intervals, clipped to ``[begin, end]``."""
        hits = sorted(self._tree.overlap(begin, end))
        return [
            Interval(max(iv.begin, begin), min(iv.end, end), iv.data)
            for iv in hits
        ]

    def overlap_content(self, begin, end):
        """Return the data overlapping ``begin``..``end``.

        A single hit is returned bare; otherwise a list of data is returned
        (empty when nothing overlaps).
        """
        hits = sorted(self._tree.overlap(begin, end))
        if len(hits) == 1:
            return hits[0].data
        return [hit.data for hit in hits]

    def value(self, index):
        """Return the data at ``index``.

        A single hit is returned as its bare data; otherwise the sorted list
        of matching ``Interval`` objects is returned (empty when there is no
        hit).
        """
        hits = sorted(self._tree.at(index))
        if len(hits) == 1:
            return hits[0].data
        return hits

    def values_at(self, indices):
        """Return ``value(i)`` for every ``i`` in ``indices``."""
        return [self.value(i) for i in indices]

    def set_interval(self, begin, end, value):
        """Replace whatever lies in ``begin``..``end`` with ``value``."""
        # Chop first so the new interval does not overlap existing ones.
        self._tree.chop(begin, end)
        self._tree.addi(begin, end, value)

    def to_df(self, title="tree"):
        """Return a pandas DataFrame with one row per interval.

        Data that is itself a ``BaseTree`` is replaced by NaN.
        """
        import pandas as pd
        ivs = []

        for (begin, end), data in self.items():
            if isinstance(data, BaseTree):
                data = float("nan")
            interval = {
                "label": f"{begin}-{end}",
                "begin": begin,
                "parameter": title,
                "mid": (begin + end) / 2,
                "end": end,
                "data": data
            }
            ivs.append(interval)
        return pd.DataFrame(ivs)

    def to_native(self):
        """Return a plain ``IntervalTree``, converting nested trees too."""
        ivs = []
        for (begin, end), data in self.items():
            if isinstance(data, BaseTree):
                iv = Interval(begin, end, data.to_native())
            else:
                iv = Interval(begin, end, data)
            ivs.append(iv)
        return IntervalTree(ivs)

    def explorer(self, title="tree"):
        """Return an ``IntervalTreeExplorer`` for this group."""
        import panel as pn
        pn.extension()
        from ..visualizations import IntervalTreeExplorer
        return IntervalTreeExplorer(tree=self, label=title)
コード例 #11
0
def generate_scans(
    isotopologue_lib: dict,
    peak_properties: dict,
    interval_tree: IntervalTree,
    fragmentor: AbstractFragmentor,
    noise_injector: AbstractNoiseInjector,
    mzml_params: dict,
):
    """Generate MS1 scans and their dependent MS2 scans over a gradient.

    Time ``t`` advances by ``ms_rt_diff`` after every scan until it reaches
    ``gradient_length``.  For each MS1 scan, the peaks of all molecules
    returned by ``interval_tree.at(t)`` are combined; afterwards MS2 scans
    are generated for the most intense molecules, subject to the dynamic
    exclusion time.

    Args:
        isotopologue_lib (dict): Keyed by molecule; each entry is indexed
            with ``"mz"`` and ``"i"``.
        peak_properties (dict): Keyed by molecule; each entry is indexed
            with ``"scan_start_time"``, ``"peak_width"`` and ``"charge"``.
        interval_tree (IntervalTree): Queried with ``at(t)``; the ``data``
            of each hit is used as a molecule key.
        fragmentor (AbstractFragmentor): Its ``fragment`` method receives a
            list of molecules; column 0 of the result is used as m/z and
            column 1 as intensity.
        noise_injector (AbstractNoiseInjector): Its ``inject_noise`` method
            is applied to every MS1 and MS2 scan.
        mzml_params (dict): Reads ``"gradient_length"``, ``"min_intensity"``,
            ``"dynamic_exclusion"`` and ``"isolation_window_width"``, plus
            the optional ``"ms_rt_diff"`` (default 0.03),
            ``"max_intensity"`` (default 1e10) and ``"max_ms2_spectra"``
            (default 10).

    Returns:
        tuple: ``(scans, mol_scan_dict)``.  ``scans`` is a list of
        ``(ms1_scan, [ms2_scans])`` tuples.  ``mol_scan_dict`` maps each
        molecule to a dict with ``"ms1_scans"`` and ``"ms2_scans"`` lists
        of recorded ``spec_id`` values.
    """
    logger.info("Initialize chimeric spectra counter")
    chimeric_count = 0
    chimeric = Counter()
    logger.info("Start generating scans")
    t0 = time.time()
    gradient_length = mzml_params["gradient_length"]
    ms_rt_diff = mzml_params.get("ms_rt_diff", 0.03)
    t: float = 0

    mol_scan_dict: Dict[str, Dict[str, list]] = {}
    scans: List[Tuple[Scan, List[Scan]]] = []
    # i: int = 0
    spec_id: int = 1
    # Maps a molecule to the time of its most recent fragmentation.
    de_tracker: Dict[str, int] = {}
    de_stats: dict = {}

    mol_scan_dict = {
        mol: {
            "ms1_scans": [],
            "ms2_scans": []
        }
        for mol in isotopologue_lib
    }
    # NOTE(review): `molecules` is not referenced again in this function.
    molecules = list(isotopologue_lib.keys())

    progress_bar = tqdm(
        total=gradient_length,
        desc="Generating scans",
        bar_format=
        "{desc}: {percentage:3.0f}%|{bar}| {n:.2f}/{total_fmt} [{elapsed}<{remaining}",
    )
    while t < gradient_length:
        # NOTE(review): the list assigned here is immediately replaced by
        # the dict on the next line.
        scan_peaks: List[Tuple[float, float]] = []
        scan_peaks = {}
        # Collects (molecule, first kept m/z, summed intensity) tuples.
        mol_i = []
        mol_monoisotopic = {}
        candidates = interval_tree.at(t)
        # print(len(candidates))
        for mol in candidates:
            # if len(candidates) > 1:
            mol = mol.data
            mol_plus = f"{mol}"
            mz = np.array(isotopologue_lib[mol]["mz"])
            intensity = np.array(isotopologue_lib[mol]["i"])
            intensity = rescale_intensity(intensity, t, mol, peak_properties,
                                          isotopologue_lib)

            # Drop peaks at or below the minimum intensity.
            mask = intensity > mzml_params["min_intensity"]
            intensity = intensity[mask]

            # clip max intensity
            intensity = np.clip(intensity,
                                a_min=None,
                                a_max=mzml_params.get("max_intensity", 1e10))

            mz = mz[mask]
            # NOTE(review): this list is overwritten by the dict built next.
            mol_peaks = list(zip(mz, intensity))
            # Peaks keyed by m/z rounded to 6 decimals.
            mol_peaks = {
                round(mz, 6): _i
                for mz, _i in list(zip(mz, intensity))
            }
            # Molecules that share an m/z value have their intensities summed
            # for that m/z.
            if len(mol_peaks) > 0:
                mol_i.append((mol, mz[0], sum(intensity)))
                # scan_peaks.extend(mol_peaks)
                # The loop below rebinds `mz` and `intensity` to scalars.
                for mz, intensity in mol_peaks.items():
                    if mz in scan_peaks:
                        scan_peaks[mz] += intensity
                    else:
                        scan_peaks[mz] = intensity
                mol_scan_dict[mol]["ms1_scans"].append(spec_id)
                # Despite the dict name, this stores the most intense peak.
                highest_peak = max(mol_peaks.items(), key=lambda x: x[1])
                mol_monoisotopic[mol] = {
                    "mz": highest_peak[0],
                    "i": highest_peak[1],
                }
        # Sorted by intensity here; re-sorted by m/z once the Scan exists.
        scan_peaks = sorted(list(scan_peaks.items()), key=lambda x: x[1])
        if len(scan_peaks) > 0:
            mz, inten = zip(*scan_peaks)
        else:
            mz, inten = [], []

        s = Scan({
            "mz": np.array(mz),
            "i": np.array(inten),
            "id": spec_id,
            "rt": t,
            "ms_level": 1,
        })
        # Remember the MS1 id so dependent MS2 scans can reference it.
        prec_scan_id = spec_id
        spec_id += 1

        sorting = s.mz.argsort()
        s.mz = s.mz[sorting]
        s.i = s.i[sorting]

        # add noise
        s = noise_injector.inject_noise(s)

        # i += 1
        scans.append((s, []))
        t += ms_rt_diff
        progress_bar.update(ms_rt_diff)

        if t > gradient_length:
            break

        fragment_spec_index = 0
        max_ms2_spectra = mzml_params.get("max_ms2_spectra", 10)
        # The cap uses the candidate count from before DE filtering below.
        if len(mol_i) < max_ms2_spectra:
            max_ms2_spectra = len(mol_i)
        ms2_scan = None
        # Most intense molecules first.
        mol_i = sorted(mol_i, key=lambda x: x[2], reverse=True)
        logger.debug(f"All molecules eluting: {len(mol_i)}")
        logger.debug(f"currently # fragment spectra {len(scans[-1][1])}")

        # Keep molecules never fragmented before, or last fragmented longer
        # ago than the dynamic exclusion time.
        mol_i = [
            mol for mol in mol_i if (de_tracker.get(mol[0], None) is None) or
            (t - de_tracker[mol[0]]) > mzml_params["dynamic_exclusion"]
        ]
        logger.debug(f"All molecules eluting after DE filtering: {len(mol_i)}")
        # Runs until the current MS1 entry holds max_ms2_spectra MS2 scans,
        # or one of the breaks below fires.
        while len(scans[-1][1]) != max_ms2_spectra:
            logger.debug(f"Frag spec index {fragment_spec_index}")
            if fragment_spec_index > len(mol_i) - 1:
                # we evaluated fragmentation for every potential mol
                # and all will be skipped
                logger.debug(f"All possible mol are skipped due to DE")
                break
            mol = mol_i[fragment_spec_index][0]
            _mz = mol_i[fragment_spec_index][1]
            _intensity = mol_i[fragment_spec_index][2]
            fragment_spec_index += 1
            mol_plus = f"{mol}"
            # Candidates whose first library m/z lies within the isolation
            # window width of this precursor's m/z.
            all_mols_in_mz_and_rt_window = [
                mol.data for mol in candidates
                if (abs(isotopologue_lib[mol.data]["mz"][0] -
                        _mz) < mzml_params["isolation_window_width"])
            ]
            if len(all_mols_in_mz_and_rt_window) > 1:
                chimeric_count += 1
                chimeric[len(all_mols_in_mz_and_rt_window)] += 1
            if mol is None:
                # Don't add empty MS2 scans, but have just as many scans as
                # precursors.
                # NOTE(review): breakpoint() drops into the debugger whenever
                # this branch runs -- looks like a leftover debugging aid;
                # confirm before shipping.
                breakpoint()
                ms2_scan = Scan({
                    "mz": np.array([]),
                    "i": np.array([]),
                    "rt": t,
                    "id": spec_id,
                    "precursor_mz": 0,
                    "precursor_i": 0,
                    "precursor_charge": 1,
                    "precursor_scan_id": prec_scan_id,
                    "ms_level": 2,
                })
                spec_id += 1
                t += ms_rt_diff
                progress_bar.update(ms_rt_diff)

                # This break leaves the inner loop before the scan built
                # above is appended.
                if t > gradient_length:
                    break
            elif (peak_properties[mol_plus]["scan_start_time"] <= t) and (
                (peak_properties[mol_plus]["scan_start_time"] +
                 peak_properties[mol_plus]["peak_width"]) >= t):
                # fragment all molecules in isolation and rt window
                # check if molecule needs to be fragmented according to dynamic_exclusion rule
                if (de_tracker.get(mol, None) is None or
                    (t - de_tracker[mol]) > mzml_params["dynamic_exclusion"]):
                    logger.debug("Generate Fragment spec")
                    de_tracker[mol] = t
                    if mol not in de_stats:
                        de_stats[mol] = {"frag_events": 0, "frag_spec_ids": []}
                    de_stats[mol]["frag_events"] += 1
                    de_stats[mol]["frag_spec_ids"].append(spec_id)
                    peaks = fragmentor.fragment(all_mols_in_mz_and_rt_window)
                    frag_mz = peaks[:, 0]
                    frag_i = peaks[:, 1]
                    ms2_scan = Scan({
                        "mz":
                        frag_mz,
                        "i":
                        frag_i,
                        "rt":
                        t,
                        "id":
                        spec_id,
                        "precursor_mz":
                        mol_monoisotopic[mol]["mz"],
                        "precursor_i":
                        mol_monoisotopic[mol]["i"],
                        "precursor_charge":
                        peak_properties[mol]["charge"],
                        "precursor_scan_id":
                        prec_scan_id,
                        "ms_level":
                        2,
                    })
                    spec_id += 1
                    ms2_scan.i = rescale_intensity(ms2_scan.i, t, mol,
                                                   peak_properties,
                                                   isotopologue_lib)
                    ms2_scan = noise_injector.inject_noise(ms2_scan)
                    # MS2 intensities are halved after noise injection.
                    ms2_scan.i *= 0.5
                else:
                    logger.debug(f"Skip {mol} due to dynamic exclusion")
                    continue
                t += ms_rt_diff
                progress_bar.update(ms_rt_diff)
                # As above: breaking here discards the scan just created.
                if t > gradient_length:
                    break
            else:
                logger.debug(f"Skip {mol} since not in RT window")
                continue
            # NOTE(review): spec_id was already incremented above, so the
            # value recorded here is one greater than the "id" of the MS2
            # scan (unlike de_stats["frag_spec_ids"]) -- confirm intent.
            if mol is not None:
                mol_scan_dict[mol]["ms2_scans"].append(spec_id)
            if ms2_scan is None:
                # there are molecules in mol_i
                # however all molecules are excluded from fragmentation_function
                # => Don't do a scan and break the while loop
                # => We should rather continue and try to fragment the next mol!
                logger.debug(
                    f"Continue and fragment next mol since MS2 scan is None")
                continue
            if (
                    len(ms2_scan.mz) > -1
            ):  # TODO -1 to also add empty ms2 specs; 0 breaks tests currently ....
                # Sort the MS2 peaks by m/z before storing the scan.
                sorting = ms2_scan.mz.argsort()
                ms2_scan.mz = ms2_scan.mz[sorting]
                ms2_scan.i = ms2_scan.i[sorting]
                logger.debug(f"Append MS2 scan with {mol}")
                scans[-1][1].append(ms2_scan)
    progress_bar.close()
    t1 = time.time()
    logger.info("Finished generating scans")
    logger.info(f"Generating scans took {t1-t0:.2f} seconds")
    logger.info(f"Found {chimeric_count} chimeric scans")

    return scans, mol_scan_dict
コード例 #12
0
class PointerExtractorPlugin:
    """Plugin that replaces marked pointers inside ``.incbin`` ranges.

    It indexes the ``.incbin "baserom.gba"`` directives found in the
    repository's assembly files, and rewrites those files so that known
    pointers become ``.4byte <symbol>`` lines and missing labels are
    inserted.
    """

    # Plugin metadata.
    name = 'Pointer Extractor'
    description = 'Extracts marked pointers from .incbins'
    # presumably read by the plugin loader; verify against the plugin API
    hidden = True

    def __init__(self, api: PluginApi) -> None:
        """Store the plugin API; the incbin index is built on demand."""
        self.api = api
        # IntervalTree of incbin ranges, or None until they were parsed.
        self.incbins = None

    def load(self) -> None:
        """Register the two menu entries of this plugin."""
        self.action_parse_incbins = self.api.register_menu_entry(
            'Parse files for .incbins', self.slot_parse_incbins)
        self.action_find_pointers = self.api.register_menu_entry(
            'Find unextracted pointers', self.slot_find_pointers)

    def unload(self) -> None:
        """Remove the menu entries registered by ``load``."""
        self.api.remove_menu_entry(self.action_parse_incbins)
        self.api.remove_menu_entry(self.action_find_pointers)

    def slot_parse_incbins(self) -> None:
        """Scan all ``.inc``/``.s`` files below the repo for incbins.

        The result is stored in ``self.incbins`` and the number of incbins
        found is shown as a message.
        """
        incbins = []

        assembly_extensions = ['.inc', '.s']
        for root, dirs, files in os.walk(settings.get_repo_location()):
            for file in files:
                filename, file_extension = os.path.splitext(file)
                if file_extension in assembly_extensions:
                    incbins.extend(self.find_incbins(os.path.join(root, file)))
        self.incbins = IntervalTree(incbins)
        self.api.show_message('Pointer Extractor',
                              f'{len(incbins)} .incbins found')

    def find_incbins(self, path: str) -> List[Interval]:
        """Return one interval per ``.incbin "baserom.gba"`` line in ``path``.

        Each interval spans ``addr`` to ``addr + length`` and carries
        ``path`` as its data.  Matching lines that do not split into exactly
        three comma-separated parts are reported with ``print`` and skipped.
        """
        incbins = []
        with open(path, 'r') as file:
            for line in file:
                line = line.strip()
                if line.startswith('.incbin "baserom.gba"'):
                    arr = line.split(',')
                    if len(arr) == 3:
                        # Address and length are parsed as hexadecimal.
                        addr = int(arr[1], 16)
                        length = int(arr[2], 16)
                        incbin = Interval(addr, addr + length, path)
                        incbins.append(incbin)
                    else:
                        print(f'Invalid incbin: {line}')
        return incbins

    def slot_find_pointers(self) -> None:
        """Find pointers inside incbins and rewrite the assembly files.

        Steps:

        1. Group the USA-variant pointers by the file whose incbin contains
           their address.
        2. Collect a missing label for every pointer whose target lies more
           than one byte past the start of its symbol.
        3. Insert those labels into the assembly files, splitting the
           incbins.
        4. Replace the remaining pointers with ``.4byte <symbol>`` lines,
           again splitting the incbins.

        The incbin index is built first if it does not exist yet.  The
        method returns early when the USA symbols are not loaded, or when a
        pointer would extend past the end of its incbin.
        """
        if self.incbins is None:
            #self.api.show_error('Pointer Extractor', 'Need to parse .incbins first')
            #return
            self.slot_parse_incbins()

        symbol_database = get_symbol_database()

        if not symbol_database.are_symbols_loaded(RomVariant.USA):
            self.api.show_error('Pointer Extractor',
                                'Symbols for USA rom need to be loaded first')
            return

        symbols = symbol_database.get_symbols(RomVariant.USA)

        pointers = get_pointer_database().get_pointers(RomVariant.USA)

        # Maps a file path to its pointers, kept sorted by address.
        to_extract: Dict[str, SortedKeyList[Pointer]] = {}

        for pointer in pointers:
            found = self.incbins.at(pointer.address)
            if len(found) == 1:
                interval = found.pop()
                file = interval.data

                if not file in to_extract:
                    to_extract[file] = SortedKeyList(key=lambda x: x.address)

                to_extract[file].add(pointer)
                # print(hex(pointer.address))
                # print(found.pop())
            elif len(found) > 1:
                print(
                    f'Found {len(found)} incbins for address {pointer.address}'
                )

        # Count unextracted pointers
        count = 0
        for file in to_extract:
            print(f'{file}: {len(to_extract[file])}')
            count += len(to_extract[file])

        self.api.show_message('Pointer Extractor',
                              f'{count} unextracted pointers found')
        print(count)

        # Find symbols that unextracted pointers point to
        missing_labels = {}
        count = 0
        for file in to_extract:
            for pointer in to_extract[file]:

                symbol = symbols.get_symbol_at(pointer.points_to - ROM_OFFSET)
                offset = pointer.points_to - ROM_OFFSET - symbol.address
                if offset > 1:  # Offset 1 is ok for function pointers
                    print(pointer)
                    if symbol.file not in missing_labels:
                        missing_labels[symbol.file] = SortedKeyList(
                            key=lambda x: x.address)
                    # Insert Missing label if there is not already one
                    label = MissingLabel(pointer.points_to - ROM_OFFSET,
                                         symbol.name, offset, symbol.file)
                    if label not in missing_labels[symbol.file]:
                        missing_labels[symbol.file].add(label)
                        count += 1
                    continue

        print(f'{count} missing labels')
        for file in missing_labels:
            print(f'{file}: {len(missing_labels[file])}')

        # Insert labels for incbins
        for path in missing_labels:
            output_lines = []
            labels = missing_labels[path]
            # The first label is popped before the file check below, so the
            # diagnostic print of missing_labels[path] no longer includes it.
            next_label = labels.pop(0)

            # Try to find source assembly file
            asm_path = os.path.join(settings.get_repo_location(),
                                    path.replace('.o', '.s'))
            if not os.path.isfile(asm_path):
                print(f'Cannot insert labels in {path}')
                print(missing_labels[path])
                continue

            with open(asm_path, 'r') as file:
                for line in file:
                    if next_label is not None and line.strip().startswith(
                            '.incbin "baserom.gba"'):
                        arr = line.split(',')
                        if len(arr) == 3:
                            addr = int(arr[1], 16)
                            length = int(arr[2], 16)

                            # Labels located before this incbin cannot be
                            # placed any more; report and skip them.
                            while next_label is not None and next_label.address < addr:
                                print(f'Cannot insert {next_label}')
                                if len(labels) == 0:  # Extracted all labels
                                    next_label = None
                                    break
                                next_label = labels.pop(0)
                                continue

                            # Split the incbin at every label that falls
                            # inside it.
                            while next_label is not None and next_label.address >= addr and next_label.address < addr + length:
                                # Calculate new incbins
                                prev_addr = addr
                                prev_length = next_label.address - addr
                                after_addr = next_label.address
                                after_length = addr + length - after_addr

                                if prev_length > 0:
                                    # print the incbin
                                    output_lines.append(
                                        incbin_line(prev_addr, prev_length))

                                # Print the label
                                label_addr = '{0:#010x}'.format(
                                    next_label.address +
                                    ROM_OFFSET).upper().replace('0X', '')
                                output_lines.append(
                                    f'gUnk_{label_addr}:: @ {label_addr}\n')

                                # Continue with the part after the label.
                                addr = after_addr
                                length = after_length

                                if len(labels) == 0:  # Extracted all labels
                                    next_label = None
                                    break
                                next_label = labels.pop(0)
                                continue

                            # Emit whatever is left of the incbin; the
                            # original line is dropped.
                            if length > 0:
                                output_lines.append(incbin_line(addr, length))
                            continue
                    output_lines.append(line)

            # Labels that could not be placed in any incbin of this file.
            while next_label is not None:

                # tmp: print label for script
                label_addr = '{0:#010x}'.format(next_label.address +
                                                ROM_OFFSET).upper().replace(
                                                    '0X', '')
                print(f'SCRIPT_START script_{label_addr}')
                print(f'at {next_label.symbol}')

                #print(f'Could not insert {next_label}')
                if len(labels) == 0:  # Extracted all labels
                    next_label = None
                    break
                next_label = labels.pop(0)

            print(f'Write {asm_path}')
            print(next_label)
            with open(asm_path, 'w') as file:
                file.writelines(output_lines)

        print('Extracting pointers')

        # Extract pointers
        for path in to_extract:
            output_lines = []
            pointers = to_extract[path]
            next_pointer = pointers.pop(0)
            with open(path, 'r') as file:
                for line in file:
                    if next_pointer is not None and line.strip().startswith(
                            '.incbin "baserom.gba"'):
                        arr = line.split(',')
                        if len(arr) == 3:
                            addr = int(arr[1], 16)
                            length = int(arr[2], 16)

                            while next_pointer.address >= addr and next_pointer.address < addr + length:
                                # Pointer is in this incbin
                                symbol = symbols.get_symbol_at(
                                    next_pointer.points_to - ROM_OFFSET)
                                offset = next_pointer.points_to - ROM_OFFSET - symbol.address
                                if offset > 1:
                                    # Missing label: the pointer is skipped
                                    # and left inside the incbin.
                                    if len(pointers
                                           ) == 0:  # Extracted all pointers
                                        next_pointer = None
                                        break
                                    next_pointer = pointers.pop(0)
                                    continue

                                # Calculate new incbins
                                prev_addr = addr
                                prev_length = next_pointer.address - addr
                                # The pointer itself occupies four bytes.
                                after_addr = next_pointer.address + 4
                                after_length = addr + length - after_addr
                                if after_length < 0:
                                    # Abort the whole extraction; this file
                                    # is left unwritten.
                                    message = f'Pointer at {hex(next_pointer.address)} crosses over from incbin at {hex(addr)}'
                                    print(path)
                                    print(after_length)
                                    print(message)
                                    self.api.show_error(
                                        'Pointer Extractor', message)
                                    return

                                if prev_length > 0:
                                    # print the incbin
                                    output_lines.append(
                                        incbin_line(prev_addr, prev_length))

                                # Print the pointer
                                output_lines.append(
                                    f'\t.4byte {symbol.name}\n')

                                # Continue with the part after the pointer.
                                addr = after_addr
                                length = after_length

                                if len(pointers
                                       ) == 0:  # Extracted all pointers
                                    next_pointer = None
                                    break
                                next_pointer = pointers.pop(0)

                            # Emit whatever is left of the incbin; the
                            # original line is dropped.
                            if length > 0:
                                output_lines.append(incbin_line(addr, length))
                            continue

                    output_lines.append(line)

            with open(path, 'w') as file:
                file.writelines(output_lines)
            #print(''.join(output_lines))
        self.api.show_message('Pointer Extractor', f'Done extracting pointers')
コード例 #13
0
ファイル: manager.py プロジェクト: nicHoch/STIXCore
class SOOPManager():
    """Manages LTP files provided by GFTS.

    Every LTP export found in the data root is parsed and its SOOPs and
    observations are stored in two interval trees (``self.soops`` and
    ``self.observations``) keyed on their start/end dates, so that they can
    be looked up by time point or time range.
    """

    SOOP_FILE_FILTER = "SSTX_observation_timeline_export_*.json"
    # The dot of the file extension is escaped so that only a literal
    # ".json" suffix is accepted (an unescaped "." matches any character).
    SOOP_FILE_REGEX = re.compile(
        r'.*SSTX_observation_timeline_export_.*\.json$')

    def __init__(self, data_root):
        """Create the manager for a given data path root.

        All existing files matching `SOOP_FILE_FILTER` are indexed right
        away (this happens in the `data_root` setter).

        Parameters
        ----------
        data_root : `str` | `pathlib.Path`
            Path to the directory with all LTP files.

        Raises
        ------
        ValueError
            if the path does not exist or contains no SOOP files.
        """
        # Number of files read so far; it is never reset, so it keeps
        # growing when `data_root` is re-assigned later on.
        self.filecounter = 0
        self.data_root = data_root

    @property
    def data_root(self):
        """Get the data path root directory.

        Returns
        -------
        `pathlib.Path`
            path of the root directory
        """
        return self._data_root

    @data_root.setter
    def data_root(self, value):
        """Set the data path root and rebuild both indexes from its files.

        The new location is validated completely before any state is
        replaced, so a failing assignment leaves the previous root and
        indexes untouched.

        Parameters
        ----------
        value : `str` or `pathlib.Path`
            Path to the directory with all LTP files.

        Raises
        ------
        ValueError
            if the path does not exist or contains no file matching
            `SOOP_FILE_FILTER`.
        """
        path = Path(value)
        if not path.exists():
            raise ValueError(f'path not found: {value}')

        # Sort by file name so the files are always indexed in the same order.
        files = sorted(path.glob(SOOPManager.SOOP_FILE_FILTER),
                       key=os.path.basename)
        if not files:
            raise ValueError(f'No current SOOP files found at: {path}')

        self._data_root = path
        self.soops = IntervalTree()
        self.observations = IntervalTree()

        for sfile in files:
            self.add_soop_file_to_index(sfile)

    def find_soops(self, *, start, end=None):
        """Search for all SOOPs in the index.

        Parameters
        ----------
        start : `datetime`
            start time to look for overlapping SOOPs in utc time
        end : `datetime`, optional
            end time to look for overlapping SOOPs in utc time, by default
            None (only the single time point `start` is queried)

        Returns
        -------
        `list`
            list of found `SOOP` in all indexed LTP overlapping the given timeperiod/point
        """
        if end is None:
            intervals = self.soops.at(start)
        else:
            intervals = self.soops.overlap(start, end)

        return [interval.data for interval in intervals]

    def find_observations(self,
                          *,
                          start,
                          end=None,
                          otype=SoopObservationType.ALL):
        """Search for all observations in the index.

        Parameters
        ----------
        start : `datetime`
            start time to look for overlapping observations in utc time
        end : `datetime`, optional
            end time to look for overlapping observations in utc time, by
            default None (only the single time point `start` is queried)
        otype : `SoopObservationType`, optional
            filter for specific type, by default SoopObservationType.ALL

        Returns
        -------
        `list`
            list of found `SOOPObservation` in all indexed LTP overlapping the given
            timeperiod/point and matching the SoopObservationType.
        """
        if end is None:
            intervals = self.observations.at(start)
        else:
            intervals = self.observations.overlap(start, end)

        # ALL means "no type filter"; any other value must match exactly.
        if otype != SoopObservationType.ALL:
            return [interval.data for interval in intervals
                    if interval.data.type == otype]

        return [interval.data for interval in intervals]

    def get_keywords(self, *, start, end=None, otype=SoopObservationType.ALL):
        """Searches for corresponding entries (SOOPs and Observations) in the index LTPs.

        Based on all found entries for the filter parameters a list of
        HeaderKeyword is generated combining all available information.

        Parameters
        ----------
        start : `datetime`
            start time to look for overlapping observations and SOOPs in utc time
        end : `datetime`, optional
            end time to look for overlapping observations and SOOPs in utc
            time, by default None (only the time point `start` is queried)
        otype : `SoopObservationType`, optional
            filter for specific type of observations, by default SoopObservationType.ALL

        Returns
        -------
        `list`
            A list of `HeaderKeyword`

        Warns
        -----
        UserWarning
            if no SOOPs, or no observations, were found in the index LTPs
            for the given filter settings. No exception is raised in that
            case; the keywords of whatever was found are still returned.
        """
        kwset = KeywordSet()

        soops = self.find_soops(start=start, end=end)
        if not soops:
            warnings.warn(f"No soops found for time: {start} - {end}",
                          UserWarning)
        for soop in soops:
            kwset.append(soop.to_fits_keywords())

        obss = self.find_observations(start=start, end=end, otype=otype)
        if not obss:
            warnings.warn(
                f"No observations found for time: {start} - {end} : {otype}",
                UserWarning)
        for obs in obss:
            kwset.append(obs.to_fits_keywords())

        return kwset.to_list()

    def add_soop_file_to_index(self, path):
        """Read one LTP JSON file and add its entries to both indexes.

        Every element of the top-level ``"soops"`` list is wrapped in a
        `SOOP` and every element of ``"observations"`` in a
        `SoopObservation`; each is stored under its start/end date.

        Parameters
        ----------
        path : `str` or `pathlib.Path`
            the LTP JSON file to read
        """
        logger.info(f"Read SOOP file: {path}")
        # JSON is read as UTF-8 explicitly instead of relying on the locale
        # default; the file is closed again as soon as it has been parsed.
        with open(path, encoding="utf-8") as f:
            ltp_data = json.load(f)

        for jsond in ltp_data["soops"]:
            soop = SOOP(jsond)
            self.soops.addi(soop.startDate, soop.endDate, soop)
        for jsond in ltp_data["observations"]:
            obs = SoopObservation(jsond)
            self.observations.addi(obs.startDate, obs.endDate, obs)

        self.filecounter += 1
コード例 #14
0
class BoxHolder():
    """
    A class to allow quick lookup of boxes (e.g. exclusion items,
    targets, etc).

    Every box is stored twice: in an interval tree keyed on its mz range
    (``boxes_mz``) and in one keyed on its rt range (``boxes_rt``). Point
    lookups go through the mz tree first, as this is likely to narrow
    things down quicker. Sub-holders restricted to the boxes active at a
    given rt or mz can be built with `get_subset_rt` / `get_subset_mz`.
    """

    def __init__(self):
        self.boxes_mz = IntervalTree()
        self.boxes_rt = IntervalTree()

    def add_box(self, box):
        """
        Add a box to both interval trees.

        The box must provide ``from_mz``, ``to_mz``, ``from_rt`` and
        ``to_rt``; it is stored as the data of both intervals.
        """
        # Read all four bounds first so that a box missing one of them is
        # rejected before either tree has been modified.
        mz_from, mz_to = box.from_mz, box.to_mz
        rt_from, rt_to = box.from_rt, box.to_rt
        self.boxes_mz.addi(mz_from, mz_to, box)
        self.boxes_rt.addi(rt_from, rt_to, box)

    def check_point(self, mz, rt):
        """
        Find the boxes that match this mz and rt value.

        Candidates come from the mz tree; each one is then asked via its
        own ``rt_match(rt)`` whether it also covers the rt value.

        Returns a set of the matching boxes (empty if there is none).
        """
        return {
            region.data
            for region in self.boxes_mz.at(mz)
            if region.data.rt_match(rt)
        }

    # NOTE: an alternative lookup that intersected the results of
    # boxes_mz.at(mz) and boxes_rt.at(rt) used to be kept here commented
    # out with a FIXME: it produced a different result from check_point and
    # must not be used. Presumably the intervals of the two trees never
    # compare equal because their bounds differ -- TODO confirm before
    # re-introducing anything similar.

    def is_in_box(self, mz, rt):
        """
        Check if this mz and rt is in *any* box
        """
        return len(self.check_point(mz, rt)) > 0

    def is_in_box_mz(self, mz):
        """
        Check if an mz value is in any box
        """
        return len(self.boxes_mz.at(mz)) > 0

    def is_in_box_rt(self, rt):
        """
        Check if an rt value is in any box
        """
        return len(self.boxes_rt.at(rt)) > 0

    def _subset_from(self, regions):
        """
        Build a new BoxHolder holding the boxes of the given intervals
        """
        subset = BoxHolder()
        for region in regions:
            subset.add_box(region.data)
        return subset

    def get_subset_rt(self, rt):
        """
        Create a new BoxHolder containing all boxes active at rt
        """
        return self._subset_from(self.boxes_rt.at(rt))

    def get_subset_mz(self, mz):
        """
        Create a new BoxHolder containing all boxes active at mz
        """
        return self._subset_from(self.boxes_mz.at(mz))