Example #1
    def predict_zero_shot_2afc(self, sentence, model1, model2):
        """
    Yield zero-shot predictions on a 2AFC sentence, marginalizing over possible
    novel lexical entries required to parse the sentence.

    TODO explain marginalization process in more detail

    Args:
      sentence: List of token strings
      models:

    Returns:
      model_scores: `Distribution` over scene models (with support `models`),
        `p(referred scene | sentence)`
    """
        parser = chart.WeightedCCGChartParser(self.lexicon)
        # The second argument requests weighted (parse, score, ...) tuples,
        # unpacked in the loop below.
        weighted_results = parser.parse(sentence, True)
        if not weighted_results:
            L.warning("Parse failed for sentence '%s'", " ".join(sentence))

            # Fall back to lexical induction: propose novel lexical entries
            # that make the sentence parseable, then re-parse.
            aug_lexicon = self.do_lexical_induction(
                sentence, (model1, model2),
                augment_lexicon_fn=augment_lexicon_2afc,
                queue_limit=50)
            parser = chart.WeightedCCGChartParser(aug_lexicon)
            weighted_results = parser.parse(sentence, True)

        dist = Distribution()

        for result, score, _ in weighted_results:
            semantics = result.label()[0].semantics()
            try:
                # Only an exact `True` denotation counts as success;
                # `evaluate` may return non-boolean values.
                model1_pass = model1.evaluate(semantics) == True
            except Exception:
                # Treat evaluation errors as failure under this model.
                pass
            else:
                if model1_pass:
                    dist[model1] += np.exp(score)

            try:
                model2_pass = model2.evaluate(semantics) == True
            except Exception:
                pass
            else:
                if model2_pass:
                    dist[model2] += np.exp(score)

        return dist.ensure_support((model1, model2)).normalize()
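
A minimal usage sketch (hedged: `learner`, `scene_a`, and `scene_b` are
illustrative names; `Distribution` is assumed to behave like a dict mapping
outcomes to probabilities):

# Illustrative usage only; the object names are hypothetical.
sentence = "the ball is left of the box".split()
scene_scores = learner.predict_zero_shot_2afc(sentence, scene_a, scene_b)

# `scene_scores` is a normalized `Distribution` with support
# `(scene_a, scene_b)`: pick the scene the model believes is referred to.
predicted_scene = max(scene_scores, key=scene_scores.get)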
Example #2
def test_zero_shot_type_request_2arg():
  """
  predict_zero_shot should infer the types of missing semantic forms and use
  as specific a type request as possible when invoking
  `Ontology.iter_expressions`.
  """
  ontology = _make_simple_mock_ontology()
  lex = Lexicon.fromstring(r"""
  :- S, N
  bar => N {qux}
  blah => N {quz}
  # dummy => S\N/N {\x y.threeplace(x,y,baz)}
  """, ontology=ontology, include_semantics=True)

  # Setup: we observe the sentence "blah foo bar"; the ground-truth semantics
  # for "foo" is \x y.threeplace(x,y,baz).

  # Mock Ontology.iter_expressions so we can inspect the requested type.
  mock = MagicMock(return_value=[])
  ontology.iter_expressions = mock

  tokens = ["foo"]
  candidate_syntaxes = {"foo": Distribution.uniform([lex.parse_category(r"S\N/N")])}
  sentence = "blah foo bar".split()

  predict_zero_shot(lex, tokens, candidate_syntaxes, sentence, ontology,
                    model=None, likelihood_fns=[])

  eq_(len(mock.call_args_list), 1)
  args, kwargs = mock.call_args
  eq_(kwargs["type_request"], ontology.types["obj", "obj", "*"])
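
The mock-inspection pattern above uses only standard `unittest.mock`; in
isolation it works like this (the argument values are illustrative):

from unittest.mock import MagicMock

mock = MagicMock(return_value=[])
mock(type_request=("obj", "obj", "*"))

args, kwargs = mock.call_args          # arguments of the most recent call
assert kwargs["type_request"] == ("obj", "obj", "*")
assert len(mock.call_args_list) == 1   # the mock was called exactly once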
Example #3
    def total_category_masses(self,
                              exclude_tokens=frozenset(),
                              soft_propagate_roots=False):
        """
    Return the total weight mass assigned to each syntactic category. Shifts
    masses such that the minimum mass is zero.

    Args:
      exclude_tokens: Exclude entries with this token from the count.
      soft_propagate_roots: Soft-propagate derived root categories. If there is
        a derived root category `D0{S}` and some lexical entry `S/N`, even if
        no entry has the category `D0{S}`, we will add a key to the returned
        counter with category `D0{S}/N` (and zero weight).

    Returns:
      masses: `Distribution` mapping from category types to masses. The minimum
        mass value is zero and the maximum is unbounded.
    """
        ret = Distribution()
        # Track categories with root yield.
        rooted_cats = set()

        for token, entries in self._entries.items():
            if token in exclude_tokens:
                continue
            for entry in entries:
                c_yield = get_yield(entry.categ())
                if c_yield in self._starts:
                    rooted_cats.add((c_yield, entry.categ()))

                if entry.weight() > 0.0:
                    ret[entry.categ()] += entry.weight()

        if soft_propagate_roots:
            for c_yield, rooted_cat in rooted_cats:
                for derived_root_cat in self._derived_categories_by_base[
                        c_yield]:
                    soft_prop_cat = set_yield(rooted_cat, derived_root_cat)
                    # Ensure key exists.
                    ret.setdefault(soft_prop_cat, 0.0)

        return ret
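
A hypothetical call site: using the masses as a prior over syntactic
categories when proposing entries for a novel token. `lexicon` and
`novel_token` are illustrative names; `normalize()` is the same
`Distribution` method used in Example #1.

# Hypothetical: score candidate categories for a novel token by how much
# weight the rest of the lexicon assigns to each category.
masses = lexicon.total_category_masses(exclude_tokens={novel_token},
                                       soft_propagate_roots=True)
category_prior = masses.normalize()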
Example #4
    def lf_ngrams_mixed(self, alpha=0.25, **kwargs):
        """
    Return conditional distributions over logical form n-grams conditioned on
    syntactic category, calculated by mixing two distribution classes: a
    distribution conditioned on the full syntactic category and a distribution
    conditioned on the yield of the category.
    """
        lf_syntax_ngrams = self.lf_ngrams_given_syntax(**kwargs)
        lf_support = lf_syntax_ngrams.support

        # Soft-propagate derived root categories.
        for syntax in list(lf_syntax_ngrams.dists.keys()):
            syn_yield = get_yield(syntax)
            if syn_yield in self._starts:
                for derived_root_cat in self._derived_categories_by_base[
                        syn_yield]:
                    new_yield = set_yield(syntax, derived_root_cat)
                    if new_yield not in lf_syntax_ngrams:
                        lf_syntax_ngrams[new_yield] = Distribution.uniform(
                            lf_support)

        # Second distribution: P(pred | root)
        lf_yield_ngrams = self.lf_ngrams(
            conditioning_fn=lambda entry: [get_yield(entry.categ())], **kwargs)
        # Mix full-category and primitive-category predictions.
        lf_mixed_ngrams = ConditionalDistribution()
        for syntax in lf_syntax_ngrams:
            # # Mix distributions conditioned on the constituent primitives.
            # primitives = get_category_primitives(syntax)
            # prim_alpha = 1 / len(primitives)

            # Mix root-conditioned distribution and the full syntax-conditioned
            # distribution.
            yield_dist = lf_yield_ngrams[get_yield(syntax)]
            lf_mixed_ngrams[syntax] = lf_syntax_ngrams[syntax].mix(
                yield_dist, alpha)

        return lf_mixed_ngrams
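
The mixture itself is a convex combination of the two conditional
distributions. A self-contained sketch of the operation, with plain dicts
standing in for `Distribution` and assuming, for illustration, that
`mix(other, alpha)` computes `(1 - alpha) * self + alpha * other`:

# Illustrative only: convex combination of two predicate distributions.
def mix(dist_a, dist_b, alpha=0.25):
    support = set(dist_a) | set(dist_b)
    return {k: (1 - alpha) * dist_a.get(k, 0.0) + alpha * dist_b.get(k, 0.0)
            for k in sorted(support)}

p_full_category = {"foo": 0.9, "bar": 0.1}  # P(pred | full syntactic category)
p_yield = {"foo": 0.5, "bar": 0.5}          # P(pred | category yield)
print(mix(p_full_category, p_yield))        # {'bar': 0.2, 'foo': 0.8}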