Example #1
    def test_intersection_de_facto_empty(self):
        rng = Range.open(3, 4)
        self.assertEqual(rng, rng.intersection(rng))
        self.assertEqual(Range.open_closed(3, 3),
                         rng.intersection(Range.at_most(3)))
        self.assertEqual(Range.closed_open(4, 4),
                         rng.intersection(Range.at_least(4)))

        with self.assertRaises(ValueError):
            rng.intersection(Range.less_than(3))
        with self.assertRaises(ValueError):
            rng.intersection(Range.greater_than(4))

        rng2 = Range.closed(3, 4)
        self.assertEqual(Range.open_closed(4, 4),
                         rng2.intersection(Range.greater_than(4)))
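For intuition, "de facto empty" means an intersection that collapses to a zero-width interval at a shared endpoint. A minimal sketch of that behavior, assuming the same Range API the test above uses:

# (3, 4) and (-inf, 3] touch only at 3, which (3, 4) excludes,
# so the resulting interval (3, 3] contains no values at all.
empty = Range.open(3, 4).intersection(Range.at_most(3))
assert empty == Range.open_closed(3, 3)
assert empty.is_empty()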
Example #2
    def test_greater_than(self):
        rng = Range.greater_than(5)
        self.assertFalse(5 in rng)
        self.assertTrue(6 in rng)
        self.assertTrue(sys.maxsize in rng)
        self.assertTrue(rng.has_lower_bound())
        self.assertEqual(5, rng.lower_endpoint)
        self.assertEqual(BoundType.open(), rng.lower_bound_type)
        self.assert_unbounded_above(rng)
        self.assertFalse(rng.is_empty())
        self.assertEqual("(5..+\u221e)", str(rng))
Example #3
    def test_encloses_closed(self):
        rng = Range.closed(2, 5)
        self.assertTrue(rng.encloses(rng))
        self.assertTrue(rng.encloses(Range.open(2, 5)))
        self.assertTrue(rng.encloses(Range.open_closed(2, 5)))
        self.assertTrue(rng.encloses(Range.closed_open(2, 5)))
        self.assertTrue(rng.encloses(Range.closed(3, 5)))
        self.assertTrue(rng.encloses(Range.closed(2, 4)))

        self.assertFalse(rng.encloses(Range.open(1, 6)))
        self.assertFalse(rng.encloses(Range.greater_than(3)))
        self.assertFalse(rng.encloses(Range.less_than(3)))
        self.assertFalse(rng.encloses(Range.at_least(3)))
        self.assertFalse(rng.encloses(Range.at_most(3)))
        self.assertFalse(rng.encloses(Range.all()))
Example #4
    def test_intersection_singleton(self):
        rng = Range.closed(3, 3)
        self.assertEqual(rng, rng.intersection(rng))

        self.assertEqual(rng, rng.intersection(Range.at_most(4)))
        self.assertEqual(rng, rng.intersection(Range.at_most(3)))
        self.assertEqual(rng, rng.intersection(Range.at_least(3)))
        self.assertEqual(rng, rng.intersection(Range.at_least(2)))

        self.assertEqual(Range.closed_open(3, 3),
                         rng.intersection(Range.less_than(3)))
        self.assertEqual(Range.open_closed(3, 3),
                         rng.intersection(Range.greater_than(3)))

        with self.assertRaises(ValueError):
            rng.intersection(Range.at_least(4))
        with self.assertRaises(ValueError):
            rng.intersection(Range.at_most(2))
Example #5
class TestRangeSet(TestCase):
    """
    Tests for RangeSet

    Derived from Guava's TreeRangeSet tests, which were written by Louis Wasserman and Chris Povirk
    """

    MIN_BOUND = -1
    MAX_BOUND = 1
    BOUND_TYPES = [BoundType.open(), BoundType.closed()]

    QUERY_RANGES: List[Range[int]] = [Range.all()]
    for i in range(MIN_BOUND, MAX_BOUND + 1):
        QUERY_RANGES.extend([
            Range.at_most(i),
            Range.at_least(i),
            Range.less_than(i),
            Range.greater_than(i),
            Range.closed(i, i),
            Range.open_closed(i, i),
            Range.closed_open(i, i),
        ])
        for j in range(i + 1, MAX_BOUND + 1):
            QUERY_RANGES.extend([
                Range.open(i, j),
                Range.open_closed(i, j),
                Range.closed_open(i, j),
                Range.closed(i, j),
            ])

    def test_empty_enclosing(self):
        self._test_encloses(RangeSet.create_mutable())

    def test_empty_intersects(self):
        self._test_intersects(RangeSet.create_mutable())

    def test_all_single_ranges_enclosing(self):
        for query_range in TestRangeSet.QUERY_RANGES:
            self._test_encloses(RangeSet.create_mutable().add(query_range))
        # also test for the complement of empty once complements are implemented

    def test_all_pair_ranges_enclosing(self):
        for query_range_1 in TestRangeSet.QUERY_RANGES:
            for query_range_2 in TestRangeSet.QUERY_RANGES:
                self._test_encloses(RangeSet.create_mutable().add(
                    query_range_1).add(query_range_2))

    def test_intersect_ranges(self):
        range_set = RangeSet.create_mutable()
        range_set.add_all([
            Range.closed(2, 4),
            Range.closed(5, 7),
            Range.closed(10, 12),
            Range.closed(18, 20),
        ])
        self.assertEqual(
            range_set.ranges_overlapping(Range.closed(19, 21)),
            immutableset([Range.closed(18, 20)]),
        )
        self.assertEqual(
            range_set.ranges_overlapping(Range.closed(11, 19)),
            immutableset([Range.closed(10, 12),
                          Range.closed(18, 20)]),
        )
        self.assertEqual(range_set.ranges_overlapping(Range.closed(0, 1)),
                         immutableset())
        self.assertEqual(range_set.ranges_overlapping(Range.closed(21, 23)),
                         immutableset())
        self.assertEqual(range_set.ranges_overlapping(Range.closed(13, 15)),
                         immutableset())
        self.assertEqual(
            range_set.ranges_overlapping(Range.closed(0, 2)),
            immutableset([Range.closed(2, 4)]),
        )
        self.assertEqual(
            range_set.ranges_overlapping(Range.closed(12, 15)),
            immutableset([Range.closed(10, 12)]),
        )
        self.assertEqual(
            range_set.ranges_overlapping(Range.closed(5, 16)),
            immutableset([Range.closed(5, 7),
                          Range.closed(10, 12)]),
        )

    def test_merges_connected_with_overlap(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 4))
        range_set.add(Range.open(2, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed_open(1, 6) in range_set.as_ranges())

    def test_merges_connected_disjoint(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 4))
        range_set.add(Range.open(4, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed_open(1, 6) in range_set.as_ranges())

    def test_ignores_smaller_sharing_no_bound(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 6))
        range_set.add(Range.open(2, 4))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed(1, 6) in range_set.as_ranges())

    def test_ignores_smaller_sharing_lower_bound(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 6))
        range_set.add(Range.closed(1, 4))
        self._test_invariants(range_set)
        self.assertEqual(tuple([Range.closed(1, 6)]),
                         tuple(range_set.as_ranges()))

    def test_ignores_smaller_sharing_upper_bound(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 6))
        range_set.add(Range.closed(3, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed(1, 6) in range_set.as_ranges())

    def test_ignores_equal(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 6))
        range_set.add(Range.closed(1, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed(1, 6) in range_set.as_ranges())

    def test_extend_same_lower_bound(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(1, 4))
        range_set.add(Range.closed(1, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed(1, 6) in range_set.as_ranges())

    def test_extend_same_upper_bound(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(3, 6))
        range_set.add(Range.closed(1, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed(1, 6) in range_set.as_ranges())

    def test_extend_both_directions(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(3, 4))
        range_set.add(Range.closed(1, 6))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed(1, 6) in range_set.as_ranges())

    def test_add_empty(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed_open(3, 3))
        self._test_invariants(range_set)
        self.assertTrue(len(range_set.as_ranges()) == 0)
        self.assertTrue(range_set.is_empty())

    def test_fill_hole_exactly(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed_open(1, 3))
        range_set.add(Range.closed_open(4, 6))
        range_set.add(Range.closed_open(3, 4))
        self._test_invariants(range_set)
        self.assertTrue(Range.closed_open(1, 6) in range_set.as_ranges())

    def test_fill_hole_with_overlap(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed_open(1, 3))
        range_set.add(Range.closed_open(4, 6))
        range_set.add(Range.closed_open(2, 5))
        self._test_invariants(range_set)
        self.assertEqual(tuple([Range.closed_open(1, 6)]),
                         tuple(range_set.as_ranges()))

    def test_add_many_pairs(self):
        for a_low in range(0, 6):
            for a_high in range(0, 6):
                if a_low > a_high:
                    continue
                a_ranges = [
                    Range.closed(a_low, a_high),
                    Range.open_closed(a_low, a_high),
                    Range.closed_open(a_low, a_high),
                ]
                if a_low != a_high:
                    a_ranges.append(Range.open(a_low, a_high))

                for b_low in range(0, 6):
                    for b_high in range(0, 6):
                        if b_low > b_high:
                            continue
                        b_ranges = [
                            Range.closed(b_low, b_high),
                            Range.open_closed(b_low, b_high),
                            Range.closed_open(b_low, b_high),
                        ]
                        if b_low != b_high:
                            b_ranges.append(Range.open(b_low, b_high))
                        for a_range in a_ranges:
                            for b_range in b_ranges:
                                self._pair_test(a_range, b_range)

    def test_range_containing1(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(3, 10))
        self.assertEqual(Range.closed(3, 10), range_set.range_containing(5))
        self.assertTrue(5 in range_set)
        self.assertIsNone(range_set.range_containing(1))
        self.assertFalse(1 in range_set)

    def test_add_all(self):
        range_set = RangeSet.create_mutable()
        range_set.add(Range.closed(3, 10))
        range_set.add_all(
            [Range.open(1, 3),
             Range.closed(5, 8),
             Range.closed(9, 11)])
        self.assertEqual(tuple(range_set.as_ranges()),
                         tuple([Range.open_closed(1, 11)]))

    def test_all_single_ranges_intersecting(self):
        for query in TestRangeSet.QUERY_RANGES:
            self._test_intersects(RangeSet.create_mutable().add(query))

    def test_all_two_ranges_intersecting(self):
        for query_1 in TestRangeSet.QUERY_RANGES:
            for query_2 in TestRangeSet.QUERY_RANGES:
                self._test_intersects(
                    RangeSet.create_mutable().add(query_1).add(query_2))

    # forms the basis for corresponding tests in test_range_map
    def test_rightmost_containing_or_below(self):
        range_set = RangeSet.create_mutable().add_all((
            Range.closed(-2, -1),
            Range.closed_open(0, 2),
            # we use 2.1 and 5.1 (rather than 2 and 5) because [0, 2) and [2, 3]
            # would coalesce, as would (4, 5] and (5, 7)
            Range.closed(2.1, 3),
            Range.open_closed(4, 5),
            Range.open(5.1, 7),
        ))

        # probe value is in the middle of a set
        # [2.1  ... *2.5* ... 3]
        self.assertEqual(Range.closed(2.1, 3.0),
                         range_set.rightmost_containing_or_below(2.5))
        # probe value is at a closed upper limit
        # [2.1 .... *3*]
        self.assertEqual(Range.closed(2.1, 3.0),
                         range_set.rightmost_containing_or_below(3.0))
        # probe value is at a closed lower limit
        # [*2.1* .... 3]
        self.assertEqual(Range.closed(2.1, 3.0),
                         range_set.rightmost_containing_or_below(2.1))
        # probe value is at an open lower limit
        # [2.1 ... 3], (*4* ... 5]
        self.assertEqual(Range.closed(2.1, 3.0),
                         range_set.rightmost_containing_or_below(4.0))
        # probe value is at an open upper limit
        # [0 ... *2*)
        self.assertEqual(Range.closed_open(0.0, 2.0),
                         range_set.rightmost_containing_or_below(2.0))
        # probe value falls into a gap
        # [-2, -1] ... *-0.5* ... [0, 2)
        self.assertEqual(Range.closed(-2.0, -1.0),
                         range_set.rightmost_containing_or_below(-0.5))
        # no range below
        # *-3* .... [-2,-1]
        self.assertIsNone(range_set.rightmost_containing_or_below(-3.0))
        # empty rangeset
        self.assertIsNone(RangeSet.create_mutable().add(Range.closed(
            1.0, 2.0)).rightmost_containing_or_below(0.0))
        # lowest range has open lower bound
        # (*1*,2)
        self.assertIsNone(RangeSet.create_mutable().add(Range.open(
            1.0, 2.0)).rightmost_containing_or_below(1.0))

    # forms the basis for corresponding tests in test_range_map
    def test_leftmost_containing_or_above(self):
        range_set = RangeSet.create_mutable().add_all((
            Range.closed(-2, -1),
            Range.closed_open(0, 2),
            # we use 2.1 and 5.1 (rather than 2 and 5) because [0, 2) and [2, 3]
            # would coalesce, as would (4, 5] and (5, 7)
            Range.closed(2.1, 3),
            Range.open_closed(4, 5),
            Range.open(5.1, 7),
        ))

        # probe value is in the middle of a set
        # [2.1  ... *2.5* ... 3]
        self.assertEqual(Range.closed(2.1, 3.0),
                         range_set.leftmost_containing_or_above(2.5))
        # probe value is at a closed upper limit
        # [2.1 .... *3*]
        self.assertEqual(Range.closed(2.1, 3.0),
                         range_set.leftmost_containing_or_above(3.0))
        # probe value is at a closed lower limit
        # [*2.1* .... 3]
        self.assertEqual(Range.closed(2.1, 3.0),
                         range_set.leftmost_containing_or_above(2.1))
        # probe value is at an open lower limit
        # [2.1 ... 3], (*4* ... 5]
        self.assertEqual(Range.open_closed(4.0, 5.0),
                         range_set.leftmost_containing_or_above(4.0))
        # probe value is at an open upper limit
        # [0 ... *2*) [2.1, 3.0]
        self.assertEqual(Range.closed(2.1, 3.0),
                         range_set.leftmost_containing_or_above(2.0))
        # probe value falls into a gap
        # [-2, -1] ... *-0.5* ... [0, 2)
        self.assertEqual(Range.closed_open(0, 2),
                         range_set.leftmost_containing_or_above(-0.5))
        # no range above
        # (5.1 ... 7) ... *8*
        self.assertIsNone(range_set.leftmost_containing_or_above(8))
        # empty rangeset
        self.assertIsNone(RangeSet.create_mutable().add(Range.closed(
            1.0, 2.0)).leftmost_containing_or_above(3.0))
        # higher range has open upper bound
        # (1,*2*)
        self.assertIsNone(RangeSet.create_mutable().add(Range.open(
            1.0, 2.0)).leftmost_containing_or_above(2.0))

    def test_len(self):
        self.assertEqual(0, len(RangeSet.create_mutable()))
        self.assertEqual(
            1, len(RangeSet.create_mutable().add(Range.closed(1, 2))))
        self.assertEqual(
            2,
            len(RangeSet.create_mutable().add(Range.closed(1, 2)).add(
                Range.open(3, 4))),
        )

    # support methods

    def _pair_test(self, a: Range[int], b: Range[int]) -> None:
        range_set: MutableRangeSet[int] = RangeSet.create_mutable()
        range_set.add(a)
        range_set.add(b)
        if a.is_empty() and b.is_empty():
            self.assertTrue(range_set.is_empty())
            self.assertFalse(range_set.as_ranges())
        elif a.is_empty():
            self.assertTrue(b in range_set.as_ranges())
        elif b.is_empty():
            self.assertTrue(a in range_set.as_ranges())
        elif a.is_connected(b):
            self.assertEqual(tuple(range_set.as_ranges()), tuple([a.span(b)]))
        else:
            if a.lower_endpoint < b.lower_endpoint:
                self.assertEqual(tuple(range_set.as_ranges()), tuple([a, b]))
            else:
                self.assertEqual(ImmutableSet.of([a, b]),
                                 ImmutableSet.of(range_set.as_ranges()))

    def _test_encloses(self, range_set: RangeSet[int]):
        self.assertTrue(range_set.encloses_all(ImmutableSet.empty()))
        for query_range in TestRangeSet.QUERY_RANGES:
            expected_to_enclose = any(
                x.encloses(query_range) for x in range_set.as_ranges())
            self.assertEqual(expected_to_enclose,
                             range_set.encloses(query_range))
            self.assertEqual(expected_to_enclose,
                             range_set.encloses_all([query_range]))

    def _test_intersects(self, range_set: RangeSet[int]):
        for query in TestRangeSet.QUERY_RANGES:
            expect_intersects = any(
                r.is_connected(query) and not r.intersection(query).is_empty()
                for r in range_set.as_ranges())
            self.assertEqual(expect_intersects, range_set.intersects(query))

    def _test_invariants(self, range_set: RangeSet[int]):
        self.assertEqual(len(range_set.as_ranges()) == 0, range_set.is_empty())
        as_ranges: Sequence[Range[int]] = tuple(range_set.as_ranges())

        # test that connected ranges are coalesced
        for (range_1, range_2) in tile_with_pairs(as_ranges):
            self.assertFalse(range_1.is_connected(range_2))

        for rng in as_ranges:
            self.assertFalse(rng.is_empty())

        # test that the RangeSet's span is the span of all the ranges
        if as_ranges:
            self.assertEqual(Range.create_spanning(range_set.as_ranges()),
                             range_set.span)
        else:
            with self.assertRaises(ValueError):
                # pylint: disable=pointless-statement
                # noinspection PyStatementEffect
                range_set.span

    # test internal utility functions

    def test_entry_above_below(self):
        sorted_dict = SortedDict({1: 1, 3: 3, 5: 5, 7: 7, 9: 9})
        value_at_or_below_reference = (
            (0, None),
            (1, 1),
            (2, 1),
            (3, 3),
            (4, 3),
            (5, 5),
            (6, 5),
            (7, 7),
            (8, 7),
            (9, 9),
            (10, 9),
            (200, 9),
        )
        for (key, ref) in value_at_or_below_reference:
            self.assertEqual(_value_at_or_below(sorted_dict, key), ref)

        value_below_reference = (
            (0, None),
            (1, None),
            (2, 1),
            (3, 1),
            (4, 3),
            (5, 3),
            (6, 5),
            (7, 5),
            (8, 7),
            (9, 7),
            (10, 9),
            (200, 9),
        )
        for (key, ref) in value_below_reference:
            self.assertEqual(_value_below(sorted_dict, key), ref)

        value_at_or_above_reference = (
            (0, 1),
            (1, 1),
            (2, 3),
            (3, 3),
            (4, 5),
            (5, 5),
            (6, 7),
            (7, 7),
            (8, 9),
            (9, 9),
            (10, None),
            (200, None),
        )
        for (key, ref) in value_at_or_above_reference:
            self.assertEqual(_value_at_or_above(sorted_dict, key), ref)

    def test_pickling(self):
        empty_mutable_rangeset = MutableRangeSet.create_mutable()
        empty_immutable_rangeset = ImmutableRangeSet.builder().build()
        ranges = (Range.closed(0, 2), Range.closed(5, 29), Range.closed(35, 39))
        mutable_rangeset = MutableRangeSet.create_mutable().add_all(ranges)
        immutable_rangeset = ImmutableRangeSet.builder().add_all(
            ranges).build()

        self.assertEqual(empty_mutable_rangeset,
                         pickle.loads(pickle.dumps(empty_mutable_rangeset)))
        self.assertEqual(empty_immutable_rangeset,
                         pickle.loads(pickle.dumps(empty_immutable_rangeset)))
        self.assertEqual(mutable_rangeset,
                         pickle.loads(pickle.dumps(mutable_rangeset)))
        self.assertEqual(immutable_rangeset,
                         pickle.loads(pickle.dumps(immutable_rangeset)))

        self.assertEqual(empty_mutable_rangeset.__getstate__(), ())
        self.assertEqual(empty_immutable_rangeset.__getstate__(), ())

        self.assertEqual(mutable_rangeset.__getstate__(), ranges)
        self.assertEqual(immutable_rangeset.__getstate__(), ranges)
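A condensed sketch of the coalescing behavior these tests exercise, using the same RangeSet API (the values are illustrative):

rs = RangeSet.create_mutable()
rs.add(Range.closed(1, 4))
rs.add(Range.open(4, 6))  # connected to [1, 4], so the two coalesce into [1, 6)
assert tuple(rs.as_ranges()) == (Range.closed_open(1, 6),)
assert 5 in rs
assert rs.range_containing(5) == Range.closed_open(1, 6)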
Example #6
_T = TypeVar("_T")


# TODO: move this to vistautils
def _in_range(_range: Range[_T]) -> Callable[[Any, Any, Any], None]:
    def validator(obj, attribute: Attribute, value) -> None:  # type: ignore
        if value not in _range:
            raise ValueError(
                f"Attribute {attribute.name}'s value is not in required range {_range} for object"
                f" of type {type(obj)}")

    return validator


_positive = _in_range(Range.greater_than(0.0))  # pylint:disable=invalid-name
_non_negative = _in_range(Range.at_least(0.0))  # pylint:disable=invalid-name
_degrees = _in_range(Range.closed_open(-360.0, 360.0))  # pylint:disable=invalid-name


@attrs(frozen=True)
class Cylinder:
    """
    A cylinder, irrespective of orientation.

    Marr's representation builds objects up from generalized cylinders; right now we only
    represent cylinders with circular cross-sections.
    """

    length_in_meters: float = attrib(validator=_positive, kw_only=True)
    diameter_in_meters: float = attrib(validator=_positive, kw_only=True)
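A brief usage sketch for the _in_range validator factory above; the failing constructor call is hypothetical and is shown only to illustrate the raised ValueError:

# A cylinder with positive dimensions passes the _positive validators silently.
cylinder = Cylinder(length_in_meters=2.0, diameter_in_meters=0.5)

# A non-positive dimension fails validation; the message names the offending
# attribute and the required range.
try:
    Cylinder(length_in_meters=-1.0, diameter_in_meters=0.5)
except ValueError as error:
    print(error)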
Example #7
class AbstractCrossSituationalLearner(AbstractTemplateLearnerNew, ABC):
    """
    An Abstract Implementation of the Cross Situation Learning Model

    This learner aims to learn via storing all possible meanings and narrowing down to one meaning
    by calculating association scores and probability based off those association scores for each
    utterance situation pair. It does so be associating all words to certain meanings. For new words
    meanings that are not associated strongly to another word already are associated evenly. For
    words encountered before, words are associated more strongly to meanings encountered with that
    word before and less strongly to newer meanings. Lastly, very familiar word meaning pairs are
    associated together only, these would be words generally considered lexicalized. Once
    associations are made a probability for each word meaning pair being correct is calculated.
    Finally if the probability is high enough the word is lexicalized. More information can be
    found here: https://onlinelibrary.wiley.com/doi/full/10.1111/j.1551-6709.2010.01104.x
    """
    @attrs(slots=True, eq=False, frozen=True)
    class Hypothesis:
        pattern_template: PerceptionGraphTemplate = attrib(
            validator=instance_of(PerceptionGraphTemplate))
        association_score: float = attrib(validator=instance_of(float),
                                          default=0.0)
        # the closed range lets the 0.0 default pass validation before any update
        probability: float = attrib(validator=in_(Range.closed(0, 1)),
                                    default=0.0)
        observation_count: int = attrib(default=1)

    _ontology: Ontology = attrib(validator=instance_of(Ontology), kw_only=True)
    _observation_num = attrib(init=False, default=0)
    _surface_template_to_concept: Dict[SurfaceTemplate,
                                       Concept] = attrib(init=False,
                                                         default=Factory(dict))
    _concept_to_surface_template: Dict[Concept, SurfaceTemplate] = attrib(
        init=False, default=Factory(dict))
    _concept_to_hypotheses: ImmutableDict[
        Concept,
        ImmutableSet["AbstractCrossSituationalLearner.Hypothesis"]] = attrib(
            init=False, default=Factory(dict))

    # Learner Internal Values
    _smoothing_parameter: float = attrib(validator=in_(
        Range.greater_than(0.0)),
                                         kw_only=True)
    """
    This smoothing factor is added to the scores of all hypotheses
    when forming a probability distribution over hypotheses.
    This should be a small value, at most 0.1 and possibly much less.
    See section 3.3 of the Cross-Situational paper.
    """
    _expected_number_of_meanings: float = attrib(validator=in_(
        Range.greater_than(0.0)),
                                                 kw_only=True)
    _graph_match_confirmation_threshold: float = attrib(default=0.8,
                                                        kw_only=True)
    _lexicon_entry_threshold: float = attrib(default=0.8, kw_only=True)
    _minimum_observation_amount: int = attrib(default=5, kw_only=True)

    _concepts_in_utterance: ImmutableSet[Concept] = attrib(
        init=False, factory=immutableset)
    _updated_hypotheses: Dict[Concept,
                              ImmutableSet[Hypothesis]] = attrib(init=False,
                                                                 factory=dict)

    # Corresponds to the dummy word from the paper
    _dummy_concept: Concept = attrib(init=False)

    # Debug Values
    _debug_callback: Optional[DebugCallableType] = attrib(default=None)
    _graph_logger: Optional[GraphLogger] = attrib(validator=optional(
        instance_of(GraphLogger)),
                                                  default=None)

    @_dummy_concept.default
    def _init_dummy_concept(self):
        return self._new_concept("_cross_situational_dummy_concept")

    def _pre_learning_step(
        self, language_perception_semantic_alignment:
        LanguagePerceptionSemanticAlignment
    ) -> None:
        # Figure out what "words" (concepts) appear in the utterance.
        concepts_in_utterance = []
        for other_bound_surface_template in self._candidate_templates(
                language_perception_semantic_alignment):
            # We have seen this template before and already have a concept for it
            if (other_bound_surface_template.surface_template
                    in self._surface_template_to_concept):
                concept = self._surface_template_to_concept[
                    other_bound_surface_template.surface_template]
            # Otherwise, make a new concept for it
            else:
                concept = self._new_concept(
                    debug_string=other_bound_surface_template.surface_template.
                    to_short_string())
            self._surface_template_to_concept[
                other_bound_surface_template.surface_template] = concept
            self._concept_to_surface_template[
                concept] = other_bound_surface_template.surface_template
            concepts_in_utterance.append(concept)
        self._concepts_in_utterance = immutableset(concepts_in_utterance)

        # We only need to make a shallow copy of our old hypotheses
        # because the values of self._concept_to_hypotheses are immutable.
        self._updated_hypotheses = dict(self._concept_to_hypotheses)

    def _learning_step(
        self,
        language_perception_semantic_alignment:
        LanguagePerceptionSemanticAlignment,
        bound_surface_template: SurfaceTemplateBoundToSemanticNodes,
    ) -> None:
        """
        Try to learn the semantics of a `SurfaceTemplate` given the assumption
        that its argument slots (if any) are bound to objects according to
        *bound_surface_template*.

        For example, "try to learn the meaning of 'red' given the language 'red car'
        and an alignment of 'car' to particular perceptions in the perception graph."
        """
        # Generate all possible meanings from the Graph
        meanings_from_perception = immutableset(
            self._hypotheses_from_perception(
                language_perception_semantic_alignment,
                bound_surface_template))
        meanings_to_pattern_template: Mapping[
            PerceptionGraph, PerceptionGraphTemplate] = immutabledict(
                (meaning,
                 PerceptionGraphTemplate.from_graph(meaning, immutabledict()))
                for meaning in meanings_from_perception)

        # We check for meanings that are described by lexicalized concepts
        # and don't try to learn those lexicalized concepts further.
        # jac: Not mentioned in the part of the paper I read. New?
        concepts_to_remove: Set[Concept] = set()

        def check_and_remove_meaning(
            other_concept: Concept,
            hypothesis: "AbstractCrossSituationalLearner.Hypothesis",
            *,
            ontology: Ontology,
        ) -> None:
            match = compute_match_ratio(
                hypothesis.pattern_template,
                language_perception_semantic_alignment.
                perception_semantic_alignment.perception_graph,
                ontology=ontology,
            )
            if match and match.matching_subgraph:
                for meaning in meanings_from_perception:
                    if match.matching_subgraph.check_isomorphism(
                            meanings_to_pattern_template[meaning].graph_pattern
                    ):
                        concepts_to_remove.add(other_concept)

        for (other_concept, hypotheses) in self._concept_to_hypotheses.items():
            for hypothesis in hypotheses:
                if hypothesis.probability > self._lexicon_entry_threshold:
                    check_and_remove_meaning(other_concept,
                                             hypothesis,
                                             ontology=self._ontology)

        # We have seen this template before and already have a concept for it
        # So we attempt to verify our already picked concept
        if bound_surface_template.surface_template in self._surface_template_to_concept:
            # We don't directly associate surface templates with perceptions.
            # Instead we mediate the relationship with "concept" objects.
            # These don't matter now, but the split might be helpful in the future
            # when we might have multiple ways of expressing the same idea.
            concept = self._surface_template_to_concept[
                bound_surface_template.surface_template]
        else:
            concept = self._new_concept(debug_string=bound_surface_template.
                                        surface_template.to_short_string())
        self._surface_template_to_concept[
            bound_surface_template.surface_template] = concept
        self._concept_to_surface_template[
            concept] = bound_surface_template.surface_template

        concepts_after_preprocessing = immutableset([
            concept for concept in self._concepts_in_utterance
            if concept not in concepts_to_remove
            # TODO Does it make sense to include a dummy concept/"word"? The paper has one so I
            #  am including it for now.
        ] + [self._dummy_concept])

        # Step 0. Update priors for any meanings as-yet unobserved.

        # Step 1. Compute alignment probabilities (pp. 1029)
        # We have an identified "word" (concept) from U(t)
        # and a collection of meanings from the scene S(t).
        # We now want to calculate the alignment probabilities,
        # which will be used to update this concept's association scores, assoc(w|m, U(t), S(t)),
        # and meaning probabilities, p(m|w).
        alignment_probabilities = self._get_alignment_probabilities(
            concepts_after_preprocessing, meanings_from_perception)

        # We have an identified "word" (concept) from U(t)
        # and a collection of meanings from the scene S(t).
        # We now want to update p(.|w), which means calculating the probabilities.
        new_hypotheses = self._updated_meaning_probabilities(
            concept,
            meanings_from_perception,
            meanings_to_pattern_template,
            alignment_probabilities,
        )

        # Finally, update our hypotheses for this concept
        self._updated_hypotheses[concept] = new_hypotheses

    def _post_learning_step(
        self, language_perception_semantic_alignment:
        LanguagePerceptionSemanticAlignment
    ) -> None:
        # Finish updating hypotheses
        # We have to do this as a separate step
        # so that we can update our hypotheses for each concept
        # independently of the hypotheses for the other concepts,
        # as in the algorithm described by the paper.
        self._concept_to_hypotheses = immutabledict(self._updated_hypotheses)
        self._updated_hypotheses.clear()

    def _get_alignment_probabilities(
        self, concepts: Iterable[Concept],
        meanings: ImmutableSet[PerceptionGraph]
    ) -> ImmutableDict[Concept, ImmutableDict[PerceptionGraph, float]]:
        """
        Compute the concept-(concrete meaning) alignment probabilities for a given word
        as defined by the paper below:

        a(c|m, U(t), S(t)) = p^(t-1)(m|c) / sum(for c' in U(t) union {d}) p^(t-1)(m|c')

        where c and m are the given concept and meaning, U(t) is the set of concepts in
        the current utterance, and d is the dummy concept.
        https://onlinelibrary.wiley.com/doi/full/10.1111/j.1551-6709.2010.01104.x (3)
        """
        def meaning_probability(meaning: PerceptionGraph,
                                concept: Concept) -> float:
            """
            Return the meaning probability p^(t-1)(m|c).
            """
            # If we've already observed this concept before,
            if concept in self._concept_to_hypotheses:
                # And if we've already observed this meaning before,
                maybe_ratio_with_preexisting_hypothesis = self._find_similar_hypothesis(
                    meaning, self._concept_to_hypotheses[concept])
                if maybe_ratio_with_preexisting_hypothesis:
                    # return the prior probability.
                    _, preexisting_hypothesis = maybe_ratio_with_preexisting_hypothesis
                    return preexisting_hypothesis.probability
                # Otherwise, if we have observed this concept before
                # but not paired with a perception like this meaning,
                # it is assigned zero probability.
                # Is this correct?
                else:
                    return 0.0
            # If we haven't observed this concept before,
            # its prior probability is evenly split among all the observed meanings in this perception.
            else:
                return 1.0 / len(meanings)

        meaning_to_concept_to_alignment_probability: Dict[
            PerceptionGraph, ImmutableDict[Concept, float]] = dict()
        for meaning in iter(meanings):
            # We want to calculate the alignment probabilities for each concept against this meaning.
            # First, we compute the prior meaning probabilities p(m|c),
            # the probability that the concept c means m for each meaning m observed in the scene.
            concept_to_meaning_probability: Mapping[Concept,
                                                    float] = immutabledict({
                                                        concept:
                                                        meaning_probability(
                                                            meaning, concept)
                                                        for concept in concepts
                                                    })
            total_probability_mass: float = sum(
                concept_to_meaning_probability.values())

            # We use these to calculate the alignment probabilities a(c|m, U(t), S(t)).
            meaning_to_concept_to_alignment_probability[
                meaning] = immutabledict({
                    concept: meaning_probability_ / total_probability_mass
                    for concept, meaning_probability_ in
                    concept_to_meaning_probability.items()
                })

        # Restructure meaning_to_concept_to_alignment_probability
        # to get a map concept_to_meaning_to_alignment_probability.
        concept_to_meaning_to_alignment_probability: Dict[
            Concept, Dict[PerceptionGraph, float]] = dict()
        for (meaning, concept_to_alignment_probability) in (
                meaning_to_concept_to_alignment_probability.items()):
            for (concept, alignment_probability) in (
                    concept_to_alignment_probability.items()):
                concept_to_meaning_to_alignment_probability.setdefault(
                    concept, dict())[meaning] = alignment_probability
        return immutabledict(
            (concept, immutabledict(meaning_to_alignment_probability))
            for (concept, meaning_to_alignment_probability) in
            concept_to_meaning_to_alignment_probability.items())

    def _updated_meaning_probabilities(
        self,
        concept: Concept,
        meanings: Iterable[PerceptionGraph],
        meaning_to_pattern: Mapping[PerceptionGraph, PerceptionGraphTemplate],
        alignment_probabilities: Mapping[Concept, Mapping[PerceptionGraph,
                                                          float]],
    ) -> ImmutableSet["AbstractCrossSituationalLearner.Hypothesis"]:
        """
        Update all concept-(abstract meaning) probabilities for a given word
        as defined by the paper below:

        p(m|c) = (assoc(m, c) + lambda) / (sum(for m' in M)(assoc(m', c)) + (beta * lambda))

        where c and m are given concept and meanings, lambda is a smoothing factor, M is all
        meanings encountered, beta is an upper bound on the expected number of meaning types.
        https://onlinelibrary.wiley.com/doi/full/10.1111/j.1551-6709.2010.01104.x (3)
        """
        old_hypotheses = self._concept_to_hypotheses.get(
            concept, immutableset())

        # First we calculate the new association scores for each observed meaning.
        # If a meaning was not observed this instance, we don't change its association score at all.
        updated_hypotheses: Set[
            "AbstractCrossSituationalLearner.Hypothesis"] = set()
        hypothesis_updates: List[
            "AbstractCrossSituationalLearner.Hypothesis"] = []
        for meaning in meanings:
            # We use a placeholder probability to keep the hypothesis constructor happy.
            # We are going to fix up this probability later.
            placeholder_probability = 0.5
            # First, check if we've observed this meaning before.
            ratio_similar_hypothesis_pair = self._find_similar_hypothesis(
                meaning, old_hypotheses)
            if ratio_similar_hypothesis_pair is not None:
                ratio, similar_hypothesis = ratio_similar_hypothesis_pair

                # If we *have* observed this meaning before,
                # we need to update the existing hypothesis for it.
                if ratio.match_ratio > self._graph_match_confirmation_threshold:
                    # Mark the old hypothesis as updated
                    # so we don't include both the old and new hypothesis in our output.
                    updated_hypotheses.add(similar_hypothesis)
                    new_association_score = (
                        similar_hypothesis.association_score +
                        alignment_probabilities[concept][meaning])
                    new_observation_count = similar_hypothesis.observation_count + 1
                    new_hypothesis = AbstractCrossSituationalLearner.Hypothesis(
                        pattern_template=similar_hypothesis.pattern_template,
                        association_score=new_association_score,
                        probability=placeholder_probability,
                        observation_count=new_observation_count,
                    )
                    hypothesis_updates.append(new_hypothesis)
                    continue

            # If we *haven't* observed this meaning before,
            # we need to create a new hypothesis for it.
            new_hypothesis = AbstractCrossSituationalLearner.Hypothesis(
                pattern_template=meaning_to_pattern[meaning],
                association_score=alignment_probabilities[concept][meaning],
                probability=placeholder_probability,
                observation_count=1,
            )
            hypothesis_updates.append(new_hypothesis)

        # Now we calculate the updated meaning probabilities p(m|w).
        total_association_score = sum(hypothesis.association_score
                                      for hypothesis in hypothesis_updates)
        smoothing_term = self._expected_number_of_meanings * self._smoothing_parameter
        return immutableset(
            chain(
                # Include old hypotheses that weren't updated
                [
                    old_hypothesis for old_hypothesis in old_hypotheses
                    if old_hypothesis not in updated_hypotheses
                ],
                # Include new and updated hypotheses
                [
                    evolve(
                        hypothesis,
                        # Replace the placeholder meaning probability with the true meaning probability,
                        # calculated using the association scores and smoothing term.
                        probability=(hypothesis.association_score +
                                     self._smoothing_parameter) /
                        (total_association_score + smoothing_term),
                    ) for hypothesis in hypothesis_updates
                ],
            ))

    def _find_similar_hypothesis(
        self,
        new_meaning: PerceptionGraph,
        candidates: Iterable["AbstractCrossSituationalLearner.Hypothesis"],
    ) -> Optional[Tuple[PartialMatchRatio,
                        "AbstractCrossSituationalLearner.Hypothesis"]]:
        """
        Finds the hypothesis in candidates most similar to new_meaning and returns it
        together with the match ratio.

        Returns None if no candidate can be found that is sufficiently similar to new_meaning. A candidate is
        sufficiently similar if and only if its match ratio with new_meaning is at least
        _graph_match_confirmation_threshold.
        """
        candidates_iter = iter(candidates)
        match = None
        while match is None:
            try:
                existing_hypothesis = next(candidates_iter)
            except StopIteration:
                return None

            try:
                match = compute_match_ratio(
                    existing_hypothesis.pattern_template,
                    new_meaning,
                    ontology=self._ontology,
                )
            except RuntimeError:
                # Occurs when no matches of the pattern are found in the graph. This seems
                # to indicate some full matches and some matches with no intersection at all.
                pass

        for candidate in candidates:
            try:
                new_match = compute_match_ratio(candidate.pattern_template,
                                                new_meaning,
                                                ontology=self._ontology)
            except RuntimeError:
                # Occurs when no matches of the pattern are found in the graph. This seems
                # to indicate some full matches and some matches with no intersection at all.
                new_match = None
            if new_match and new_match.match_ratio > match.match_ratio:
                match = new_match
                existing_hypothesis = candidate
        if (match.match_ratio >= self._graph_match_confirmation_threshold
                and match.matching_subgraph and existing_hypothesis):
            return match, existing_hypothesis
        else:
            return None

    def templates_for_concept(
            self, concept: Concept) -> AbstractSet[SurfaceTemplate]:
        if concept in self._concept_to_surface_template:
            return immutableset([self._concept_to_surface_template[concept]])
        else:
            return immutableset()

    def concepts_to_patterns(self) -> Dict[Concept, PerceptionGraphPattern]:
        def argmax(hypotheses):
            # TODO is this key correct? what IS our "best hypothesis"?
            return max(
                hypotheses,
                key=lambda hypothesis: (
                    hypothesis.probability,
                    hypothesis.association_score,
                ),
            )

        return {
            concept: argmax(hypotheses).pattern_template.graph_pattern
            for concept, hypotheses in self._concept_to_hypotheses.items()
        }

    @abstractmethod
    def _new_concept(self, debug_string: str) -> Concept:
        """
        Create a new `Concept` of the appropriate type with the given *debug_string*.
        """

    @abstractmethod
    def _hypotheses_from_perception(
        self,
        learning_state: LanguagePerceptionSemanticAlignment,
        bound_surface_template: SurfaceTemplateBoundToSemanticNodes,
    ) -> Iterable[PerceptionGraph]:
        """
        Get hypotheses for the meaning of *bound_surface_template* from a given *learning_state*.
        """

    def _primary_templates(
            self) -> Iterable[Tuple[Concept, PerceptionGraphTemplate, float]]:
        return (
            (concept, hypothesis.pattern_template, hypothesis.probability)
            for (concept, hypotheses) in self._concept_to_hypotheses.items()
            # We are confident in a hypothesis if its probability is above our
            # _lexicon_entry_threshold and we've seen it at least our
            # _minimum_observation_amount times
            for hypothesis in hypotheses
            if hypothesis.observation_count >= self._minimum_observation_amount
            and hypothesis.probability >= self._lexicon_entry_threshold)

    def _fallback_templates(
            self) -> Iterable[Tuple[Concept, PerceptionGraphTemplate, float]]:
        # Alternate hypotheses either below our _lexicon_entry_threshold or our _minimum_observation_amount
        return (
            (concept, hypothesis.pattern_template, hypothesis.probability)
            for (concept, hypotheses) in self._concept_to_hypotheses.items()
            for hypothesis in sorted(
                hypotheses,
                key=lambda hypothesis: hypothesis.probability,
                reverse=True)
            if hypothesis.observation_count < self._minimum_observation_amount
            or hypothesis.probability < self._lexicon_entry_threshold)

    def _match_template(
        self,
        *,
        concept: Concept,
        pattern: PerceptionGraphTemplate,
        perception_graph: PerceptionGraph,
    ) -> Iterable[Tuple[PerceptionGraphPatternMatch, SemanticNode]]:
        """
        Try to match our model of the semantics to the perception graph
        """
        partial_match = compute_match_ratio(
            pattern,
            perception_graph,
            ontology=self._ontology,
            graph_logger=self._graph_logger,
            debug_callback=self._debug_callback,
        )

        if (partial_match.match_ratio >=
                self._graph_match_confirmation_threshold
                and partial_match.matching_subgraph):
            # If there is a match above our minimum match ratio,
            # use that pattern to try to find a match in the scene.
            # There should be one.
            # TODO: This currently means we match to the graph multiple times. Reduce this?
            matcher = partial_match.matching_subgraph.matcher(
                perception_graph,
                match_mode=MatchMode.NON_OBJECT,
                debug_callback=self._debug_callback,
            )
            found_match = False
            for match in matcher.matches(use_lookahead_pruning=True):
                found_match = True
                semantic_node_for_match = pattern_match_to_semantic_node(
                    concept=concept, pattern=pattern, match=match)
                yield match, semantic_node_for_match
            # We raise an error if we find a partial match but don't manage to match it to the scene
            if not found_match:
                raise RuntimeError(
                    f"Partial Match found for {concept} below match ratio however pattern "
                    f"subgraph was unable to match to perception graph.\n"
                    f"Partial Match: {partial_match}\n"
                    f"Perception Graph: {perception_graph}")
Example #8
    def test_equals(self):
        self.assertEqual(Range.all(), Range.all())
        self.assertEqual(Range.greater_than(2), Range.greater_than(2))
        self.assertEqual(Range.open(1, 5), Range.open(1, 5))