def _extract_sequence_wildcards(self, operands,
                                 constraints):
    """Split *operands* into automaton pattern indices and sequence-wildcard info.

    Operands that are not sequence wildcards are wrapped in a ``Pattern``
    together with the constraints mentioning their variables. The pattern is
    looked up in ``self.automaton.patterns``; if it is not there yet it is
    registered with ``self.automaton._internal_add``. Each resulting index is
    added to a ``Multiset``.

    Sequence wildcards are recorded in a dict keyed by variable name (``None``
    for unnamed wildcards). Each value is a ``(VariableWithCount, wrap)``
    tuple.

    Returns:
        ``(pattern_set, pattern_vars)`` as described above.
    """
    index_multiset = Multiset()
    wildcard_info = dict()
    for op in operands:
        if self._is_sequence_wildcard(op):
            name = getattr(op, u'variable_name', None)
            if name is not None:
                # A named wildcard: count its occurrences and keep the wrap
                # flag decided on its first occurrence.
                if name in wildcard_info:
                    (_, seen_count, _), wrap = wildcard_info[name]
                else:
                    seen_count = 0
                    wrap = op.fixed_size and self.associative
                wildcard_info[name] = (VariableWithCount(name, seen_count + 1, op.min_count), wrap)
            else:
                # Unnamed wildcards share the ``None`` key; their minimum
                # lengths accumulate into a single entry.
                if name in wildcard_info:
                    (_, _, previous_min), _ = wildcard_info[name]
                else:
                    previous_min = 0
                wildcard_info[name] = (VariableWithCount(name, 1, op.min_count + previous_min), False)
            continue

        relevant = [c for c in constraints if contains_variables_from_set(op, c.variables)]
        pattern = Pattern(op, *relevant)
        index = None
        for position, (known, _, _) in enumerate(self.automaton.patterns):
            if pattern == known:
                index = position
                break
        if index is None:
            index = self.automaton._internal_add(pattern, None, {})
        index_multiset.add(index)
    return index_multiset, wildcard_info
Example #2
0
 def test_add_twice_remove_twice(self):
     """Adding an element twice and then removing it twice leaves it absent."""
     multiset = Multiset(10)
     for _ in range(2):
         multiset.add('a')
     for _ in range(2):
         multiset.remove('a')
     self.assertTrue('a' not in multiset)
def guess_input(raw):
    """Print summary information about the input.

    Reports the following:
    - the number of lines;
    - the range of non-empty line lengths;
    - the character count;
    - the number of double-newline separators;
    - whether tabs or other non-newline whitespace occur;
    - the sorted distinct non-whitespace characters;
    - how many integers ``extract_ints`` finds, and their range;
    - the non-numeric words ordered by descending frequency.

    Input with no non-empty lines is handled without raising.
    """
    lines = raw.splitlines()
    print(f"# lines: {len(lines)}")
    line_lengths = {len(line) for line in lines if line}
    if line_lengths:
        print(f"Line length range: {min(line_lengths)} to {max(line_lengths)}")
    else:
        # min()/max() raise ValueError on an empty set.
        print("Line length range: n/a")
    print(f"# chars: {len(raw)}")
    # Counts non-overlapping occurrences, same as len(raw.split("\n\n")) - 1.
    double_newline = raw.count("\n\n")
    print(f"# double newlines: {double_newline}")
    whitespace, tabs = False, False
    seen = set()  # set: O(1) membership instead of scanning a list per char
    for ch in raw:
        if ch == '\n':
            continue
        if ch == '\t':
            tabs = True
        elif ch.isspace():
            whitespace = True
        else:
            seen.add(ch)
    print(f"Contains tabs: {tabs}")
    print(f"Contains whitespace: {whitespace}")
    print(f"Chars: {''.join(sorted(seen))}")
    ints = list(extract_ints(raw, negative=True))
    print(f"# Ints: {len(ints)}")
    if ints:
        print(f"Int range: {min(ints)} to {max(ints)}")
    ms = Multiset()
    for word in raw.split():
        if word.isnumeric():
            continue
        ms.add(word)
    common = sorted(ms.items(), key=by_index(1), reverse=True)
    print(f"Most common words: {common}")
    print()
Example #4
0
def __eval_number_of(element: Element, *args, **kwargs) -> Multiset:
    """Evaluate a ``numberof`` element into a multiset of colours.

    The first child is evaluated with ``eval_term``. If it is an ``int``, it
    is the multiplicity and the second child gives the colour. Otherwise the
    first child is itself the colour and the multiplicity is 1.

    How the colour is interpreted:
      * ``str``: returns ``{colour: num}``.
      * ``tuple``: returns ``{colour: num}``, unless some component is a
        ``list``. In that case every combination (cartesian product) of the
        components is added ``num`` times; non-list components count as a
        single fixed value.
      * ``list``: each entry is added ``num`` times.

    Raises:
        Exception: if the element is not a ``numberof`` element or a child
            evaluates to an unsupported type.
    """
    if get_tag(element) != 'numberof':
        raise Exception('Element is not a number of operator!')
    num = eval_term(element[0], *args, **kwargs)
    if not isinstance(num, (int, list, str, tuple)):
        raise Exception('The first child must either be a number, or a color!')
    if isinstance(num, int):
        color = eval_term(element[1], *args, **kwargs)
    else:
        # The first child was already the colour; the multiplicity defaults to 1.
        color, num = num, 1
    if not isinstance(color, (str, list, tuple)):
        raise Exception('The second child of the number of operator must evaluate to a color')

    if isinstance(color, str):
        return Multiset({color: num})

    if isinstance(color, tuple):
        if any(isinstance(c, list) for c in color):
            # Wrap scalar components so product() treats each as one value.
            # Previously a str component such as 'red' was iterated character
            # by character.
            axes = [c if isinstance(c, list) else [c] for c in color]
            ms = Multiset()
            for combination in product(*axes):
                ms.add(combination, num)
            return ms
        return Multiset({color: num})

    ms = Multiset()
    for col in color:
        ms.add(col, num)
    return ms
Example #5
0
 def match(self, subjects: Sequence[Expression], substitution: Substitution) -> Iterator[Tuple[int, Substitution]]:
     """Yield ``(pattern_index, substitution)`` for each pattern matching *subjects*.

     Each subject is mapped via ``self.subjects`` to an id and the pattern ids
     it can match.

     Patterns with fixed sub-patterns (``pattern_set``):
     - They are skipped unless those ids form a sub-multiset of the collected
       pattern ids.
     - Otherwise they are matched with ``_match_with_bipartite``.
     - The left-over subjects are then distributed over the pattern's sequence
       variables. If it has none, the left-over must be empty.

     Patterns without fixed sub-patterns are matched through their sequence
     variables alone, or match only an empty subject list.
     """
     subject_ids = Multiset()
     pattern_ids = Multiset()
     if self.max_optional_count > 0:
         # NOTE(review): the entry stored under ``None`` presumably stands for
         # an omitted optional argument. Its id is added once, and its
         # pattern ids once per possible optional.
         none_id, none_pattern_ids = self.subjects[None]
         subject_ids.add(none_id)
         for _ in range(self.max_optional_count):
             pattern_ids.update(none_pattern_ids)
     for subject in op_iter(subjects):
         sid, spids = self.subjects[subject]
         subject_ids.add(sid)
         pattern_ids.update(spids)
     for pattern_index, pattern_set, pattern_vars in self.patterns.values():
         if not pattern_set:
             if pattern_vars:
                 candidates = self._match_sequence_variables(Multiset(op_iter(subjects)), pattern_vars, substitution)
                 for variable_substitution in candidates:
                     yield pattern_index, variable_substitution
             elif op_len(subjects) == 0:
                 yield pattern_index, substitution
             continue
         if not pattern_set <= pattern_ids:
             continue
         bipartite_matches = self._match_with_bipartite(subject_ids, pattern_set, substitution)
         for bipartite_substitution, matched_subjects in bipartite_matches:
             leftover_ids = subject_ids - matched_subjects
             remaining = Multiset(
                 self.subjects_by_id[sid] for sid in leftover_ids if self.subjects_by_id[sid] is not None
             )
             if pattern_vars:
                 candidates = self._match_sequence_variables(remaining, pattern_vars, bipartite_substitution)
                 for result_substitution in candidates:
                     yield pattern_index, result_substitution
             elif len(remaining) == 0:
                 yield pattern_index, bipartite_substitution
 def match(self, subjects, substitution):
     """Yield ``(pattern_index, substitution)`` for each pattern matching *subjects*.

     Patterns with fixed sub-patterns:
     - They are only tried when their pattern ids are covered by the ids the
       subjects can match.
     - They are then matched with ``_match_with_bipartite``.
     - Left-over subjects go to the sequence variables. Without sequence
       variables, every subject must be consumed by the fixed sub-patterns.

     Patterns with only sequence variables are matched against all subjects.
     A pattern with neither matches only an empty subject list.
     """
     subject_ids = Multiset()
     pattern_ids = Multiset()
     for subject in subjects:
         sid, spids = self.subjects[subject]
         subject_ids.add(sid)
         pattern_ids.update(spids)
     for pattern_index, pattern_set, pattern_vars in self.patterns.values():
         if not pattern_set:
             if pattern_vars:
                 candidates = self._match_sequence_variables(Multiset(subjects), pattern_vars, substitution)
                 for variable_substitution in candidates:
                     yield pattern_index, variable_substitution
             elif len(subjects) == 0:
                 yield pattern_index, substitution
             continue
         if not pattern_set <= pattern_ids:
             continue
         bipartite_matches = self._match_with_bipartite(subject_ids, pattern_set, substitution)
         for bipartite_substitution, matched_subjects in bipartite_matches:
             if not pattern_vars:
                 if len(subjects) == len(pattern_set):
                     yield pattern_index, bipartite_substitution
                 continue
             leftover_ids = subject_ids - matched_subjects
             remaining = Multiset(self.subjects[sid] for sid in leftover_ids)  # pylint: disable=not-an-iterable
             candidates = self._match_sequence_variables(remaining, pattern_vars, bipartite_substitution)
             for result_substitution in candidates:
                 yield pattern_index, result_substitution
def multiply_units(*args):
    """Combine the given units into a product.

    Units whose ``get_unit_type`` is ``'multiply'`` are flattened into their
    component multiset (``get_complexunit_set``); any other unit is added
    as-is. If exactly one unit results, it is returned directly. Otherwise
    ``('multiply', FrozenMultiset(...))`` is returned.
    """
    combined = Multiset()
    for unit in args:
        if get_unit_type(unit) != 'multiply':
            combined.add(unit)
            continue
        combined.update(get_complexunit_set(unit))
    if len(combined) != 1:
        return ('multiply', FrozenMultiset(combined))
    return next(iter(combined))
Example #8
0
 def test_counts(self):
     """The count of an element follows each single add and remove."""
     multiset = Multiset(10)
     self.assertEqual(multiset['a'], 0)
     for expected in (1, 2):
         multiset.add('a')
         self.assertEqual(multiset['a'], expected)
     for expected in (1, 0):
         multiset.remove('a')
         self.assertEqual(multiset['a'], expected)
Example #9
0
 def test_add_remove_many(self):
     """Many interleaved adds of two elements, then removal of all of them."""
     multiset = Multiset(200)
     for _ in range(100):
         multiset.add('a')
         multiset.add('b')
     self.assertTrue('a' in multiset)
     for element in ('a', 'b'):
         for _ in range(100):
             multiset.remove(element)
     self.assertTrue('a' not in multiset)
Example #10
0
    def test_sample_unbalance(self):
        """Sampling frequency follows multiplicity for a 99:1 split."""
        n = 10000
        ms = Multiset(n)
        for _ in range(99):
            ms.add('a')
        ms.add('b')

        counts = Counter()
        for _ in range(1000000):
            counts[ms.sample()] += 1
        # 'b' has multiplicity 1 and 'a' has 99, so the expected b:a ratio is
        # 1/99 (~0.0101). The old target of 0.01 only passed thanks to the
        # tolerance.
        self.assertAlmostEqual(counts['b'] / counts['a'], 1 / 99, 3)
Example #11
0
def generate_multiset(multiset_len, n_elements):
    """Return a random ``Multiset`` holding ``multiset_len`` integers.

    Each element is drawn uniformly from ``range(n_elements)``, i.e. exactly
    ``n_elements`` distinct values are possible. Elements are added with
    multiplicity 1. A non-positive ``multiset_len`` yields an empty multiset.

    Raises:
        ValueError: if ``n_elements < 1`` while elements still need drawing.
    """
    multiset = Multiset()

    fill = multiset_len
    while fill > 0:
        # randrange excludes the upper bound. The previous
        # randint(0, n_elements) is inclusive and produced n_elements + 1
        # distinct values.
        multiset.add(random.randrange(n_elements))
        fill -= 1

    return multiset
Example #12
0
 def _extract_sequence_wildcards(self, operands: Iterable[Expression],
                                 constraints) -> Tuple[MultisetOfInt, Dict[str, Tuple[VariableWithCount, bool]]]:
     """Split *operands* into automaton pattern indices and sequence-variable info.

     Operands that are not sequence wildcards:
     - They are wrapped in a ``Pattern`` with the constraints that mention
       their variables.
     - The pattern is looked up in ``self.automaton.patterns``. If it is
       missing, it is registered via ``_internal_add`` with an identity
       renaming of its variable names.
     - If ``is_anonymous`` says so, the new index is also added to
       ``self.anonymous_patterns``.

     Sequence wildcards are recorded per variable name (``None`` for unnamed
     ones) as ``(VariableWithCount, wrap)``.

     Wildcards with a non-``None`` ``optional`` are counted. If that count
     exceeds ``self.max_optional_count``, the attribute is raised to it.

     Returns:
         ``(pattern_set, pattern_vars)``.
     """
     pattern_set = Multiset()
     pattern_vars = dict()
     optional_seen = 0
     for operand in op_iter(operands):
         if isinstance(operand, Wildcard) and operand.optional is not None:
             optional_seen += 1

         if self._is_sequence_wildcard(operand):
             varname = getattr(operand, 'variable_name', None)
             if varname is None:
                 # Unnamed wildcards share the ``None`` key; min lengths add up.
                 previous_min = 0
                 if varname in pattern_vars:
                     (_, _, previous_min, _), _ = pattern_vars[varname]
                 pattern_vars[varname] = (
                     VariableWithCount(varname, 1, operand.min_count + previous_min, None), False
                 )
             else:
                 if varname in pattern_vars:
                     (_, seen_count, _, _), wrap = pattern_vars[varname]
                 else:
                     seen_count = 0
                     wrap = operand.fixed_size and self.associative
                 pattern_vars[varname] = (
                     VariableWithCount(varname, seen_count + 1, operand.min_count, operand.optional), wrap
                 )
             continue

         relevant = [c for c in constraints if contains_variables_from_set(operand, c.variables)]
         pattern = Pattern(operand, *relevant)
         index = None
         for position, (known, _, _) in enumerate(self.automaton.patterns):
             if pattern == known:
                 index = position
                 break
         if index is None:
             names = set(
                 e.variable_name for e in preorder_iter(pattern.expression)
                 if hasattr(e, 'variable_name') and e.variable_name is not None
             )
             index = self.automaton._internal_add(pattern, None, {name: name for name in names})
             if is_anonymous(pattern.expression):
                 self.anonymous_patterns.add(index)
         pattern_set.add(index)

     if optional_seen > self.max_optional_count:
         self.max_optional_count = optional_seen
     return pattern_set, pattern_vars
def test_add():
    """``Multiset.add`` updates membership, per-element counts and total size."""
    m = Multiset('aab')
    assert len(m) == 3

    # A multiplicity of zero is rejected.
    with pytest.raises(ValueError):
        m.add('a', 0)

    def check(element, count, size):
        assert element in m
        assert m[element] == count
        assert len(m) == size

    assert 'c' not in m
    m.add('c')
    check('c', 1, 4)

    assert 'd' not in m
    m.add('d', 42)
    check('d', 42, 46)

    m.add('c', 2)
    assert m['c'] == 3
    assert len(m) == 48
Example #14
0
 def test_sample_bias(self):
     """Sampled frequencies match each element's share of the multiset."""
     n = 1000
     ms = Multiset(n)
     weights = (('a', 900), ('b', 90), ('c', 10))
     for element, weight in weights:
         for _ in range(weight):
             ms.add(element)
     counts = Counter()
     N = 1000000
     for _ in range(N):
         counts[ms.sample()] += 1
     for element, weight in weights:
         self.assertAlmostEqual(counts[element] / N, weight / n, 2)
Example #15
0
class DiceRoll:
    ''' The outcome of rolling a collection of indistinguishable fair dice.
        Rolls are immutable so they can be used as dictionary keys.
        Includes methods for creating new rolls by adding randomly rolled
        dice to an existing roll, and creating subrolls by selecting
        certain dice from a roll.

        Internally a die showing number k is stored as k - 1 in a
        Multiset constructed with the number of sides.
    '''

    def __init__(self, pips, sides):
        ''' Creates a roll with dice showing the numbers in the given iterable
            and the given number of sides.

            pips -- an iterable over integers between 1 and sides (inclusive)
            sides -- a positive integer
        '''
        self._dice = Multiset(sides)
        self._hash = None  # computed lazily by __hash__
        for n in pips:
            self._dice.add(n - 1)

    @staticmethod
    def roll(count, sides):
        ''' Creates and returns a random roll of the given number of dice with
            the given number of sides.

            count -- a nonnegative integer
            sides -- a positive integer
        '''
        result = DiceRoll([], sides)
        result._dice.add_random(count)
        return result

    @staticmethod
    def parse(s, sides):
        ''' Returns a roll containing dice showing the numbers corresponding
            to the digits in the given string.

            s -- a string containing digits in the range 1 through sides
            sides -- a positive integer

            Raises ValueError if sides is not positive, or if s contains a
            character that is not a digit in the range 1 through sides.
        '''
        if sides <= 0:
            raise ValueError("number of sides must be positive: {0}".format(sides))
        pips = []
        for digit in s:
            if not digit.isdigit():
                raise ValueError("invalid digit {0} in {1}".format(digit, s))
            num = int(digit)
            if num < 1 or num > sides:
                raise ValueError("invalid digit {0} in {1} for {2} sides".format(digit, s, sides))
            pips.append(num)
        return DiceRoll(pips, sides)

    def size(self):
        ''' Returns the number of dice in this roll.
        '''
        return self._dice.size()

    def sides(self):
        ''' Returns the number of sides on the dice in this roll.
        '''
        return self._dice.maximum()

    def min_number(self):
        ''' Returns the minimum number showing in this roll.
            If there are no dice in this roll then the
            value returned is larger than the number of sides.
        '''
        return self._dice.min_element() + 1

    def max_number(self):
        ''' Returns the maximum number showing in this roll.
            If there are no dice in this roll then the
            value returned is zero or less.
        '''
        return self._dice.max_element() + 1

    def copy(self):
        ''' Returns a copy of this roll.
        '''
        return DiceRoll(self.as_list(), self.sides())

    def reroll(self, total):
        ''' Creates and returns a roll containing dice showing the same numbers
            as in this roll, plus additional randomly rolled dice so the total
            dice is as given.

            total -- an integer greater than or equal to the number of dice
                     in this roll

            Raises ValueError if total is less than the number of dice.
        '''
        if total < self.size():
            raise ValueError("can't reroll to fewer dice: {0} < {1}".format(total, self.size()))
        # Mutating the fresh copy is safe: its hash has not been computed yet.
        result = self.copy()
        result._dice.add_random(total - self.size())
        return result

    def add_one(self, num):
        ''' Creates and returns a roll containing the same dice as this one
            plus one showing the given number.

            num -- an integer between 1 and the number of sides on the dice
                   in this roll (inclusive)

            Raises ValueError if num is out of range.
        '''
        if num < 1 or num > self.sides():
            raise ValueError("number out of range for {1}-sided dice: {0}".format(num, self.sides()))
        result = self.copy()
        result._dice.add(num - 1)
        return result

    def subroll(self, other):
        ''' Determines if this roll is a subset of the given roll.  One roll is
            a subset of another if the dice in the two rolls have the same number
            of sides and there is a 1-1 mapping from the first roll to dice
            in the second showing the same number.

            other -- a DiceRoll
        '''
        return self.sides() == other.sides() and self._dice.subset(other._dice)

    def count(self, num):
        ''' Determines how many dice are showing the given number.

            num -- an integer
        '''
        return self._dice.count(num - 1)

    def total(self):
        ''' Returns the total showing on the dice.
        '''
        # Stored values are 0..sides-1; add 1 per die to get 1..sides.
        return self._dice.total() + self._dice.size()

    def all_subrolls(self):
        ''' Returns a list containing all the subrolls of this roll.
        '''
        sides = self.sides()

        # For each possible number, the range of how many dice showing that
        # number we can keep. For example, for four-sided dice [1, 2, 2, 4]
        # we want [0..1, 0..2, 0, 0..1].
        options = [range(self.count(i + 1) + 1) for i in range(sides)]

        # One roll per combination of how many of each number to keep.
        result = []
        for counts in itertools.product(*options):
            pips = []
            for i, kept in enumerate(counts):
                pips.extend([i + 1] * kept)
            result.append(DiceRoll(pips, sides))
        return result

    def as_list(self):
        ''' Returns a list of the numbers showing in this roll.  The
            list returned will be sorted from lowest to highest number showing.
        '''
        return [x + 1 for x in self._dice.as_list()]

    def as_tuple(self):
        ''' Returns a tuple of the numbers showing in this roll.  The tuple
            returned will be sorted from lowest to highest number showing.
        '''
        return tuple(self.as_list())

    def select_all(self, nums, maximum=None):
        ''' Returns the subroll of this roll that contains all occurrences
            of the given numbers up to the given maximum of each.  If the
            maximum is None then there is no limit.

            nums -- a list of integers between 1 and the number of sides
                    on the dice in this roll
            maximum -- an integer, or None

            Raises ValueError if any number is out of range.
        '''
        keep = []
        for n in nums:
            if n < 1 or n > self.sides():
                raise ValueError("value out of range in {0}".format(nums))
            available = self.count(n)
            keep.extend([n] * (available if maximum is None else min(maximum, available)))
        return DiceRoll(keep, self.sides())

    def select_one(self, nums):
        ''' Returns the subroll of this roll that contains one occurrence
            of each of the given numbers that are also in this roll.

            nums -- a list of integers between 1 and the number of sides on the
                    dice in this roll

            Raises ValueError if any number is out of range.
        '''
        keep = []
        for n in nums:
            if n < 1 or n > self.sides():
                raise ValueError("value out of range in {0}".format(nums))
            if self.count(n) > 0:
                keep.append(n)
        return DiceRoll(keep, self.sides())

    def longest_runs(self):
        ''' Returns a list of all the longest consecutive runs in this
            roll.  For example, if this roll is [1 2 4 4 5] then the
            list returned is [[1, 2], [4, 5]].  An empty roll yields [].
        '''
        runs = []
        longest = 0
        curr_len = 0
        for i in range(1, self.sides() + 1):
            if self.count(i) > 0:
                curr_len += 1
                if curr_len == longest:
                    runs.append(list(range(i - curr_len + 1, i + 1)))
                elif curr_len > longest:
                    runs = [list(range(i - curr_len + 1, i + 1))]
                    longest = curr_len
            else:
                curr_len = 0
        return runs

    def __str__(self):
        ''' Returns a string representation of this roll.
        '''
        return str(self.as_list())

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        if self._hash is None:
            self._hash = (self.sides(), self.as_tuple()).__hash__()
        return self._hash

    def __eq__(self, other):
        ''' Determines if this roll is equal to the given other roll.
            Two rolls are equal if their dice show the same numbers
            and have the same number of sides.  Comparing with a non-DiceRoll
            returns NotImplemented, so Python falls back to identity
            (False) instead of raising AttributeError.
        '''
        if not isinstance(other, DiceRoll):
            return NotImplemented
        return self.sides() == other.sides() and self._dice == other._dice

    def __iter__(self):
        ''' Returns an iterator over the numbers showing in this roll.
        '''
        return iter(self.as_tuple())
class Basket(object):
    """A basket that accepts bricks only in certain allowed combinations.

    A brick is identified by a value whose first item (``brick[0]``) is its
    kind code. The allowed kind combinations are listed in
    ``self.combinations``. ``self.bricks`` is the multiset of kind codes
    currently inside, and ``self.inside`` is the list of the brick values
    themselves.
    """

    def __init__(self, name='Basket'):
        self.name = name
        # Allowed full contents, as multisets of single-letter kind codes.
        self.combinations = [Multiset(code) for code in ('BG', 'GG', 'BRR', 'GRR', 'RRRR', 'O')]
        self.bricks = Multiset()
        self.inside = []

    def __str__(self):
        return str(self.inside)

    def __repr__(self):
        return self.__str__()

    def add(self, brick):
        """Try to put *brick* into the basket.

        Returns True if the brick was already inside (a message is printed),
        or if its kind still fits one of the allowed combinations. Returns
        False otherwise.
        """
        if brick in self.inside:
            print("BRICK {} IS ALREADY IN THE BASKET {}".format(
                brick, self.name))
            return True

        for option in self.combinations:
            # Only combinations that can still be completed are considered.
            if not self.bricks.issubset(option):
                continue
            kind = brick[0]
            if kind in option - self.bricks:
                self.bricks.add(kind)
                self.inside.append(brick)
                return True
        return False

    def add_from(self, bricks):
        """Call :meth:`add` for each brick, ignoring the results."""
        for brick in bricks:
            self.add(brick)

    def update_state(self, ohter_basket):
        """Replace this basket's kind multiset with a deep copy of another's.

        Only ``bricks`` is copied; ``inside`` is left unchanged.
        """
        self.bricks = deepcopy(ohter_basket.bricks)

    def remove(self, brick):
        """Remove *brick* if present; return whether it was removed."""
        if brick not in self.inside:
            return False
        self.bricks.remove(brick[0], multiplicity=1)
        self.inside.remove(brick)
        return True

    def remove_from(self, bricks):
        """Remove each brick. Return True, or the list of bricks not found."""
        not_removed = []
        for brick in bricks:
            if not self.remove(brick):
                not_removed.append(brick)
        if not_removed:
            return not_removed
        return True

    def empty(self):
        """Remove everything from the basket."""
        self.bricks.clear()
        self.inside = []

    def is_full(self):
        """Return True if the contents exactly equal an allowed combination."""
        return any(self.bricks == option for option in self.combinations)

    def is_empty(self):
        """Return True if no bricks are inside."""
        return len(self.bricks) == 0
Example #17
0
 def test_add(self):
     """A freshly added element is reported as a member."""
     multiset = Multiset(10)
     multiset.add('a')
     self.assertTrue('a' in multiset)
        changer.to_csv('darklyrics-tokens-temp.csv', index=False, sep='♣')

    # load and pre-process the data
    counter = Counter()
    data = []
    with open('darklyrics-tokens-temp.csv', 'r', encoding='utf8') as o:
        o.readline()
        for line in o:
            parts = line.split('♣')
            words_sep = str(parts[5:][0]).split(',')
            clean_parts = [re.sub(rgx, "", i) for i in words_sep]
            counter.update(clean_parts)

            ms = Multiset()
            for word in clean_parts:
                ms.add(word)

            data.append([parts[:5], ms])

    os.remove('darklyrics-tokens-temp.csv')
    print('first part ')
    min_bound = int(1 / 100 * len(data))
    common_words = Multiset()

    count = 0
    for item in counter.items():
        if item[1] >= min_bound:
            # Il 100 serve perchè l'intersection prende il numero di parole minore nel multiset
            # Vogliamo che il numero minore sia il numero di token
            count += 1
            common_words.add(item[0], 100)
Example #19
0
def _match_commutative_operation(subject_operands, pattern, substitution,
                                 constraints, matcher):
    """Yield substitutions matching a commutative *pattern* against *subject_operands*.

    Steps, as implemented below:
      1. Remove the pattern's constant operands from the subject multiset,
         giving up if they are not all present.
      2. Give up if fewer subjects remain than the pattern minimally needs.
      3. Remove the values of fixed variables already bound in
         *substitution*. For associative operations, a bound value that is
         itself an instance of the operation contributes its operands.
      4. Chain factories for the remaining expressions and fixed variables.
         Then distribute the leftover subjects over the sequence variables
         using ``commutative_sequence_variable_partition_iter``.
      5. Merge each candidate into the substitution and yield every result
         accepted by ``_check_constraints``.
    """
    subjects = Multiset(subject_operands)  # type: Multiset
    if not pattern.constant <= subjects:
        return
    subjects -= pattern.constant
    rest_expr = pattern.rest + pattern.syntactic
    needed_length = (pattern.sequence_variable_min_length +
                     pattern.fixed_variable_length + len(rest_expr) +
                     pattern.wildcard_min_length)

    if len(subjects) < needed_length:
        return

    is_associative = issubclass(pattern.operation, AssociativeOperation)

    fixed_vars = Multiset(pattern.fixed_variables)  # type: Multiset[str]
    for name, count in pattern.fixed_variables.items():
        if name in substitution:
            replacement = substitution[name]
            if is_associative and isinstance(replacement, pattern.operation):
                needed_count = Multiset(substitution[name])  # type: Multiset
            else:
                if not isinstance(replacement, Expression):
                    return
                needed_count = Multiset({replacement: 1})
            if count > 1:
                needed_count *= count
            if not needed_count <= subjects:
                return
            subjects -= needed_count
            del fixed_vars[name]

    factories = [
        _fixed_expr_factory(e, constraints, matcher) for e in rest_expr
    ]

    if not is_associative:
        for name, count in fixed_vars.items():
            min_count, symbol_type = pattern.fixed_variable_infos[name]
            factories.append(_fixed_var_iter_factory(name, count, min_count,
                                                     symbol_type, constraints))

        if pattern.wildcard_fixed is True:
            factories.append(_fixed_var_iter_factory(None, 1,
                                                     pattern.wildcard_min_length,
                                                     None, constraints))
    else:
        # Associative operations: only typed fixed variables get a factory;
        # the others are treated as sequence variables below.
        for name, count in fixed_vars.items():
            min_count, symbol_type = pattern.fixed_variable_infos[name]
            if symbol_type is not None:
                factories.append(_fixed_var_iter_factory(name, count, min_count,
                                                         symbol_type, constraints))

    expr_counter = Multiset(subjects)  # type: Multiset

    # Loop variable renamed so it does not shadow the ``substitution``
    # parameter (which only seeds generator_chain).
    for rem_expr, fixed_substitution in generator_chain((expr_counter, substitution),
                                                        *factories):
        sequence_vars = _variables_with_counts(pattern.sequence_variables,
                                               pattern.sequence_variable_infos)
        if is_associative:
            sequence_vars += _variables_with_counts(
                fixed_vars, pattern.fixed_variable_infos)
            if pattern.wildcard_fixed is True:
                sequence_vars += (VariableWithCount(
                    None, 1, pattern.wildcard_min_length), )
        if pattern.wildcard_fixed is False:
            sequence_vars += (VariableWithCount(None, 1,
                                                pattern.wildcard_min_length), )

        for sequence_subst in commutative_sequence_variable_partition_iter(
                Multiset(rem_expr), sequence_vars):
            if is_associative:
                for v in fixed_vars.distinct_elements():
                    if v not in sequence_subst:
                        continue
                    min_len = pattern.fixed_variable_infos[v].min_count
                    value = cast(Multiset, sequence_subst[v])
                    if len(value) > min_len:
                        # Keep min_len - 1 operands as-is and wrap the rest
                        # in the operation, giving exactly min_len values.
                        normal = Multiset(list(value)[:min_len - 1])
                        wrapped = pattern.operation(*(value - normal))
                        normal.add(wrapped)
                        # iterator.next() does not exist in Python 3; use the
                        # builtin next() to take the single element.
                        sequence_subst[v] = normal if min_len > 1 else next(iter(normal))
                    else:
                        assert len(
                            value
                        ) == 1 and min_len == 1, u"Fixed variables with length != 1 are not supported."
                        sequence_subst[v] = next(iter(value))
            try:
                result = fixed_substitution.union(sequence_subst)
            except ValueError:
                # Conflicting bindings: this partition does not match.
                pass
            else:
                yield from _check_constraints(result, constraints)
Example #20
0
class BasePool(object):
    """A pool of expressions that can be sampled, reacted and evolved.

    ``expressions`` holds the available expressions.
    ``tmp_removed_expressions`` holds expressions taken out by
    :meth:`pop_reactive`; they still count towards ``len(self)`` and can be
    put back with :meth:`rollback`. ``food_size`` bounds the size of
    compounds that :meth:`can_make` may assemble from atoms (``None``
    disables making).

    NOTE(review): :meth:`evolve` and :meth:`evolve_generations` call
    ``self.step()``, which is not defined here. Presumably subclasses
    provide it.
    """

    def __init__(self, N, food_size):
        self.reaction_computed = Observable()
        self.step_computed = Observable()
        self.generation_computed = Observable()
        self.expressions = Multiset(N)
        self.tmp_removed_expressions = []
        self.food_size = food_size

    def register_step_observer(self, obs):
        self.step_computed.register(obs.on_step_computed)

    def register_reaction_observer(self, obs):
        self.reaction_computed.register(obs.on_reaction_computed)

    def deregister_observers(self):
        self.reaction_computed.deregister_all()
        self.step_computed.deregister_all()

    def pop_reactive(self):
        """Sample an expression and move it to the temporarily-removed list."""
        assert len(self.expressions) > 0
        t = self.expressions.sample()
        self.tmp_remove(t)
        return t

    def rollback(self, t):
        """Return a temporarily removed expression to the pool."""
        self.tmp_removed_expressions.remove(t)
        self.append(t)

    def get_total_size(self):
        return sum(len(expr) for expr in self.expressions) + \
                sum(len(expr) for expr in self.tmp_removed_expressions)

    def __len__(self):
        return len(self.expressions) + len(self.tmp_removed_expressions)

    def __iter__(self):
        return iter(self.expressions)

    def unique(self):
        return self.expressions.unique()

    def __contains__(self, t):
        return t in self.expressions or self.can_make(t)

    def load(self, fn):
        raise NotImplementedError()

    def append(self, t):
        self.expressions.add(t)

    def tmp_remove(self, t):
        self.tmp_removed_expressions.append(t)
        self.expressions.remove(t)

    def remove(self, t):
        """Remove one *t*, preferring the temporarily-removed list.

        Returns True if something was removed, False if *t* is absent.
        """
        if t in self.tmp_removed_expressions:
            self.tmp_removed_expressions.remove(t)
            return True
        if t in self.expressions:
            self.expressions.remove(t)
            return True
        return False

    def remove_all(self, ts):
        """Remove each of *ts*, stopping (and returning False) at the first miss.

        Removals made before the miss are not undone.
        """
        for t in ts:
            if not self.remove(t):
                return False
        return True

    def apply_reaction(self, reaction):
        """Consume reactives and add products if the reactives are available.

        Notifies ``reaction_computed`` and returns True on success;
        otherwise returns False.
        """
        if self.has_or_make_reactives(reaction):
            for r in reaction.reactives:
                self.remove(r)
            for p in reaction.products:
                self.append(p)
            self.reaction_computed(self, reaction)
            return True
        else:
            return False

    def has_or_make_reactives(self, reaction):
        """Ensure every reactive is present, making missing ones from atoms.

        Returns False if some missing compound cannot be made.
        """
        reactives = reaction.reactives
        missing = self.count_missing(reactives)
        # Sanity check: a reactive that is neither in the pool (nor makeable)
        # nor temporarily removed must be reported as missing. The previous
        # code referenced the generator variable of any() here, which is not
        # bound in Python 3 and raised NameError.
        for r in reactives:
            if r not in self and r not in self.tmp_removed_expressions:
                assert missing[r] > 0
        for compound, count in missing.items():
            if count > 0:
                made = self.make(compound, count)
                # NOTE: a compound can be made before noticing that another
                # cannot be made, but it should never happen with these
                # binary reactions.
                if not made:
                    return False
        return True

    def count_missing(self, reactives):
        missing = Counter(self.expressions.count_missing(reactives))
        in_tmp = Counter(self.tmp_removed_expressions)
        return Counter({k: v - in_tmp[k] for k, v in missing.items()})

    def make(self, compound, count):
        """Assemble *count* copies of *compound* from its atoms.

        Returns False as soon as one copy cannot be made.
        """
        for _ in range(count):
            if not self.can_make(compound):
                return False
            if self.remove_all(compound.atoms()):
                self.append(compound)
        return True

    def can_make(self, ts):
        if self.food_size is None or len(ts) > self.food_size:
            return False
        return self.expressions.has_all(ts.atoms())

    def get_multiplicity(self, t):
        return self.expressions[t] - self.tmp_removed_expressions.count(t)

    def __getitem__(self, t):
        return self.get_multiplicity(t)

    def __str__(self):
        pool_strs = []
        for k in sorted(set(self.expressions),
                        key=lambda k: self.expressions[k]):
            pool_strs.append(f"{k} {self.expressions[k]}")
        return "\n".join(pool_strs)

    def serializable(self):
        return {x.serializable(): self[x] for x in self.unique()}

    def evolve(self, num_reactions, timeout_time=1):
        """Run *num_reactions* steps, notifying ``step_computed`` after each.

        ``timeout_time`` is currently unused (kept for interface compatibility).
        """
        for i in range(num_reactions):
            self.step()
            self.step_computed(self, i)

    def evolve_generations(self, num_generations):
        """Run generations of ``len(self)`` steps each.

        Notifies ``step_computed`` with a global tick after each step and
        ``generation_computed`` after each generation.
        """
        tick = 0
        for i in range(num_generations):
            for _ in range(len(self)):
                self.step()
                self.step_computed(self, tick)
                tick += 1
            self.generation_computed(i)