def predict_zero_shot_2afc(self, sentence, model1, model2):
  """
  Yield zero-shot predictions on a 2AFC sentence, marginalizing over
  possible novel lexical entries required to parse the sentence.

  Concretely, we collect all weighted parses of the sentence (inducing
  novel lexical entries first if the parse fails), credit each parse's
  exponentiated score to whichever scene model(s) its semantics evaluates
  to true under, and normalize the result over the two models.

  Args:
    sentence: List of token strings
    model1: First candidate scene model
    model2: Second candidate scene model

  Returns:
    model_scores: `Distribution` over scene models (with support
      `(model1, model2)`), `p(referred scene | sentence)`
  """
  parser = chart.WeightedCCGChartParser(self.lexicon)
  weighted_results = parser.parse(sentence, True)

  if not weighted_results:
    L.warning("Parse failed for sentence '%s'", " ".join(sentence))

    aug_lexicon = self.do_lexical_induction(
        sentence, (model1, model2),
        augment_lexicon_fn=augment_lexicon_2afc,
        queue_limit=50)
    parser = chart.WeightedCCGChartParser(aug_lexicon)
    weighted_results = parser.parse(sentence, True)

  dist = Distribution()
  for result, score, _ in weighted_results:
    semantics = result.label()[0].semantics()

    # A parse contributes probability mass to a model only if its semantics
    # evaluates to true under that model; evaluation errors contribute
    # nothing.
    try:
      model1_pass = model1.evaluate(semantics) == True
    except Exception:
      pass
    else:
      if model1_pass:
        dist[model1] += np.exp(score)

    try:
      model2_pass = model2.evaluate(semantics) == True
    except Exception:
      pass
    else:
      if model2_pass:
        dist[model2] += np.exp(score)

  return dist.ensure_support((model1, model2)).normalize()
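
# --- Illustrative sketch (not part of the library) ---------------------------
# The 2AFC marginalization above reduces to: for each candidate scene model,
# sum exp(parse score) over all parses whose semantics evaluate to true under
# that model, then normalize across the two models. The toy parse tuples and
# `evaluate*` callables below are hypothetical stand-ins for the weighted
# parser output and `model.evaluate(semantics)`.
import numpy as np

def marginalize_2afc(weighted_parses, evaluate1, evaluate2):
  """Return (p(model1), p(model2)) given (semantics, log_score) pairs."""
  mass = [0.0, 0.0]
  for semantics, log_score in weighted_parses:
    if evaluate1(semantics):
      mass[0] += np.exp(log_score)
    if evaluate2(semantics):
      mass[1] += np.exp(log_score)
  total = sum(mass) or 1.0  # avoid division by zero when neither model matches
  return mass[0] / total, mass[1] / total

parses = [("red(obj1)", -0.1), ("blue(obj1)", -2.3)]
p1, p2 = marginalize_2afc(parses,
                          evaluate1=lambda sem: sem == "red(obj1)",
                          evaluate2=lambda sem: sem == "blue(obj1)")
assert p1 > p2  # the higher-scoring parse's model wins
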
def test_zero_shot_type_request_2arg():
  """
  predict_zero_shot should infer the types of missing semantic forms and
  use as specific a type request as possible when invoking
  `Ontology.iter_expressions`.
  """
  ontology = _make_simple_mock_ontology()
  lex = Lexicon.fromstring(r"""
  :- S, N
  bar => N {qux}
  blah => N {quz}

  # dummy => S\N/N {\x y.threeplace(x,y,baz)}
  """, ontology=ontology, include_semantics=True)

  # Setup: we observe a sentence "blah foo bar". The ground-truth semantics
  # for 'foo' is \x y.threeplace(x,y,baz).

  # Mock ontology.iter_expressions so we can inspect the type request it
  # receives.
  mock = MagicMock(return_value=[])
  ontology.iter_expressions = mock

  tokens = ["foo"]
  candidate_syntaxes = {"foo": Distribution.uniform([lex.parse_category(r"S\N/N")])}
  sentence = "blah foo bar".split()
  predict_zero_shot(lex, tokens, candidate_syntaxes, sentence, ontology,
                    model=None, likelihood_fns=[])

  eq_(len(mock.call_args_list), 1)
  args, kwargs = mock.call_args
  eq_(kwargs["type_request"], ontology.types["obj", "obj", "*"])
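
# --- Illustrative sketch (not part of the test suite) ------------------------
# The assertions above rely on `unittest.mock.MagicMock` recording every call.
# `mock.call_args` holds the positional and keyword arguments of the most
# recent call, which is how the test inspects the inferred `type_request`.
# The keyword arguments in the toy call below are arbitrary stand-ins for
# whatever the real `iter_expressions` call passes.
from unittest.mock import MagicMock

mock = MagicMock(return_value=[])
mock(max_depth=3, type_request=("obj", "obj", "*"))
args, kwargs = mock.call_args
assert kwargs["type_request"] == ("obj", "obj", "*")
assert len(mock.call_args_list) == 1
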
def total_category_masses(self, exclude_tokens=frozenset(),
                          soft_propagate_roots=False):
  """
  Return the total weight mass assigned to each syntactic category.
  Only entries with positive weight contribute mass, so the minimum mass
  is zero.

  Args:
    exclude_tokens: Exclude entries with these tokens from the count.
    soft_propagate_roots: Soft-propagate derived root categories. If there
      is a derived root category `D0{S}` and some lexical entry `S/N`, even
      if no entry has the category `D0{S}`, we will add a key to the
      returned counter with category `D0{S}/N` (and zero weight).

  Returns:
    masses: `Distribution` mapping from category types to masses. The
      minimum mass value is zero and the maximum is unbounded.
  """
  ret = Distribution()
  # Track categories whose yield is a root (start) category.
  rooted_cats = set()

  for token, entries in self._entries.items():
    if token in exclude_tokens:
      continue
    for entry in entries:
      c_yield = get_yield(entry.categ())
      if c_yield in self._starts:
        rooted_cats.add((c_yield, entry.categ()))

      if entry.weight() > 0.0:
        ret[entry.categ()] += entry.weight()

  if soft_propagate_roots:
    for c_yield, rooted_cat in rooted_cats:
      for derived_root_cat in self._derived_categories_by_base[c_yield]:
        soft_prop_cat = set_yield(rooted_cat, derived_root_cat)
        # Ensure the propagated category appears as a key with zero mass.
        ret.setdefault(soft_prop_cat, 0.0)

  return ret
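
# --- Illustrative sketch (not part of the library) ---------------------------
# Soft root propagation, in miniature: if `D0{S}` is derived from root `S` and
# some entry has category `S/N`, then `D0{S}/N` should appear as a key with
# zero mass even though no entry carries it yet. The string-based categories
# and the toy `set_yield_str` helper below are simplifications of the real
# category objects and `set_yield`.
masses = {"S/N": 2.5, "N": 1.0}       # accumulated positive entry weights
derived_by_base = {"S": ["D0{S}"]}    # derived root categories per base root

def set_yield_str(cat, new_yield):
  # Toy replacement of the category's yield; real code walks the category tree.
  return cat.replace("S", new_yield, 1)

for cat in list(masses):
  base_yield = "S" if cat.startswith("S") else None
  for derived in derived_by_base.get(base_yield, []):
    masses.setdefault(set_yield_str(cat, derived), 0.0)

assert masses["D0{S}/N"] == 0.0
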
def lf_ngrams_mixed(self, alpha=0.25, **kwargs):
  """
  Return conditional distributions over logical form n-grams conditioned on
  syntactic category, calculated by mixing two distribution classes: a
  distribution conditioned on the full syntactic category and a distribution
  conditioned on the yield of the category.
  """
  lf_syntax_ngrams = self.lf_ngrams_given_syntax(**kwargs)
  lf_support = lf_syntax_ngrams.support

  # Soft-propagate derived root categories: any derived version of a rooted
  # category that is missing from the conditional distribution gets a uniform
  # distribution over the support.
  for syntax in list(lf_syntax_ngrams.dists.keys()):
    syn_yield = get_yield(syntax)
    if syn_yield in self._starts:
      for derived_root_cat in self._derived_categories_by_base[syn_yield]:
        new_category = set_yield(syntax, derived_root_cat)
        if new_category not in lf_syntax_ngrams:
          lf_syntax_ngrams[new_category] = Distribution.uniform(lf_support)

  # Second distribution: P(lf n-gram | category yield).
  lf_yield_ngrams = self.lf_ngrams(
      conditioning_fn=lambda entry: [get_yield(entry.categ())], **kwargs)

  # Mix full-category and yield-conditioned predictions.
  lf_mixed_ngrams = ConditionalDistribution()
  for syntax in lf_syntax_ngrams:
    # # Mix distributions conditioned on the constituent primitives.
    # primitives = get_category_primitives(syntax)
    # prim_alpha = 1 / len(primitives)

    # Mix the yield-conditioned distribution and the full syntax-conditioned
    # distribution.
    yield_dist = lf_yield_ngrams[get_yield(syntax)]
    lf_mixed_ngrams[syntax] = lf_syntax_ngrams[syntax].mix(yield_dist, alpha)

  return lf_mixed_ngrams
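
# --- Illustrative sketch (not part of the library) ---------------------------
# The mixture computed above combines, for each syntactic category, the n-gram
# distribution conditioned on the full category with the one conditioned only
# on its yield. Assuming `mix(other, alpha)` behaves like the convex
# combination below (an assumption about the `Distribution` API), the result
# for each n-gram `w` is:
#   p_mixed(w | syntax) = (1 - alpha) * p(w | syntax) + alpha * p(w | yield)
def mix(dist_a, dist_b, alpha):
  support = set(dist_a) | set(dist_b)
  return {w: (1 - alpha) * dist_a.get(w, 0.0) + alpha * dist_b.get(w, 0.0)
          for w in support}

p_given_syntax = {"threeplace": 0.7, "qux": 0.3}             # toy P(pred | S\N/N)
p_given_yield = {"threeplace": 0.2, "qux": 0.5, "quz": 0.3}  # toy P(pred | S)
mixed = mix(p_given_syntax, p_given_yield, alpha=0.25)
assert abs(sum(mixed.values()) - 1.0) < 1e-9  # mixing preserves normalization
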