def forward(self, parse, sentence_meta=None):
  if sentence_meta is None or sentence_meta.get("frame_str", None) is None:
    raise ValueError(
        "FrameSemanticsScorer requires a sentence_meta key frame_str")

  frame = sentence_meta["frame_str"]
  try:
    frame_idx = self.frame_to_idx[frame]
  except KeyError:
    raise ValueError("Unknown frame string %s" % frame)

  # Look up the unnormalized predicate weights for this frame and normalize
  # them into log-probabilities.
  frame_weights = self.frame_dist(frame_idx)
  predicate_logps = F.log_softmax(frame_weights, dim=0)

  score = T.zeros(())
  try:
    # Find the first token whose syntactic category is one of the designated
    # root types.
    root_verb = next(tok for _, tok in parse.pos()
                     if str(tok.categ()) in self.root_types)
  except StopIteration:
    # No root verb in this parse -- nothing to score.
    return score

  # Score the root verb's predicates and constants under the frame's
  # predicate distribution.
  for predicate in root_verb.semantics().predicates():
    score += predicate_logps[self.predicate_to_idx[predicate]]
  for constant in root_verb.semantics().constants():
    predicate = l.Variable(constant.name)
    score += predicate_logps[self.predicate_to_idx[predicate]]

  return score
def __init__(self, lexicon, frames, root_types=(r"(S\N)", r"((S\N)/N)")):
  """
  Args:
    lexicon: Lexicon whose ontology provides the predicate vocabulary over
      which frame distributions are defined.
    frames: Collection of all possible frame strings
    root_types: CCG syntactic type strings of lexical entries for which we are
      collecting frames
  """
  super().__init__(lexicon)

  self.frames = frames
  self.frame_to_idx = {
      frame: T.tensor(idx, requires_grad=False)
      for idx, frame in enumerate(sorted(self.frames))
  }

  self.root_types = set(root_types)
  self.gradients_disabled = False

  ontology = self._lexicon.ontology
  self.predicates = [
      l.Variable(val.name)
      for val in ontology.functions + ontology.constants
  ]
  self.predicate_to_idx = {
      pred: idx for idx, pred in enumerate(sorted(self.predicates))
  }

  # Represent unnormalized frame distributions as an embedding layer: one row
  # per frame, one column per predicate.
  self.frame_dist = nn.Embedding(len(self.frames), len(self.predicates))
  nn.init.zeros_(self.frame_dist.weight)
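
# A minimal sketch (not part of the class above) of the computation that
# `FrameSemanticsScorer.forward` performs with these parameters; the toy sizes
# and the choice of predicate indices 1 and 2 are illustrative only:
#
#   frame_dist = nn.Embedding(2, 3)              # 2 frames, 3 predicates
#   nn.init.zeros_(frame_dist.weight)            # uniform prior over predicates
#   logps = F.log_softmax(frame_dist(T.tensor(0)), dim=0)
#   score = logps[1] + logps[2]                  # sum of log p(predicate | frame)
#                                                # over the root verb's predicates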
def test_attempt_candidate_parse():
  """
  Find parse candidates even when the parse requires composition.
  """
  lex = Lexicon.fromstring(r"""
      :- S, N
      gives => S\N/N/N {\o x y.give(x, y, o)}
      John => N {\x.John(x)}
      Mark => N {\x.Mark(x)}
      it => N {\x.T}
      """, include_semantics=True)

  # TODO this doesn't actually require composition .. get one which does
  cand_category = lex.parse_category(r"S\N/N/N")
  cand_expressions = [l.Expression.fromstring(r"\o x y.give(x,y,o)")]
  dummy_vars = {"sends": l.Variable("F000")}

  results = attempt_candidate_parse(lex, ["sends"], [cand_category],
                                    "John sends Mark it".split(),
                                    dummy_vars)
  ok_(len(list(results)) > 0)
def combine(self, function, arg):
  if not (function.categ().is_primitive() and arg.categ().is_function()
          and arg.categ().res().is_function()):
    return

  # Type-raising matches only the innermost application.
  arg = innermostFunction(arg.categ())

  subs = function.categ().can_unify(arg.arg())
  if subs is not None:
    xcat = arg.res().substitute(subs)
    categ = FunctionalCategory(
        xcat, FunctionalCategory(xcat, function.categ(), arg.dir()),
        -(arg.dir()))

    # Compute semantics: raise `f` to `\F.F(f)`, splicing the new application
    # inside any existing lambda wrapper.
    semantics = None
    if function.semantics() is not None:
      semantics = deepcopy(function.semantics())

      # Walk down to the innermost term of the (possibly nested) lambdas.
      core = semantics
      parent = None
      while isinstance(core, l.LambdaExpression):
        parent = core
        core = core.term

      # Pick a fresh function variable not free in the core term.
      var = l.Variable("F")
      while var in core.free():
        var = l.unique_variable(pattern=var)

      # Apply the new function variable to the core term.
      core = l.ApplicationExpression(l.FunctionVariableExpression(var), core)
      if parent is not None:
        parent.term = core
      else:
        semantics = core

      semantics = l.LambdaExpression(var, semantics)

    yield categ, semantics
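
# Worked example for the type-raising combinator above (lexical items are
# illustrative; the category/semantics classes are the ones used in this
# module): raising a primitive `N {john}` against a neighboring transitive
# verb of category `(S\N)/N` unifies `N` with the innermost argument slot,
# yielding the raised category `S/(S\N)` with semantics `\F.F(john)`. For a
# lambda-wrapped input such as `\x.john(x)`, the new application is spliced
# inside the existing lambda, giving `\F x.F(john(x))`.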
def predict_zero_shot(lex, tokens, candidate_syntaxes, sentence, ontology,
                      model, likelihood_fns, queue_limit=5):
  """
  Make zero-shot predictions of the posterior `p(syntax, meaning | sentence)`
  for each of `tokens`.

  Args:
    lex: Lexicon of the learner.
    tokens: Query wordforms (strings) for which novel lexical entries should
      be predicted.
    candidate_syntaxes: Mapping from each query token to a weighted
      distribution over candidate syntactic categories.
    sentence: The observed sentence, as a sequence of token strings.
    ontology: Ontology used to enumerate candidate meaning expressions.
    model: Scene model passed through to each likelihood function.
    likelihood_fns: Collection of likelihood functions
      `p(meanings | syntaxes, sentence, model)` used to score candidate
      meaning--syntax settings for a subset of `tokens`. Each function should
      accept arguments `(tokens, candidate_categories, candidate_meanings,
      candidate_semantic_parse, model)`, where `tokens` are assigned specific
      categories given in `candidate_categories` and specific meanings given
      in `candidate_meanings`, yielding a single semantic analysis of the
      sentence `candidate_semantic_parse`. The function should return a
      log-likelihood `p(candidate_meanings | candidate_syntaxes, sentence,
      model)`.

  Returns:
    candidate_queue: A ranked queue of candidates of the form
      `(logprob, (tokens, candidate_categories, candidate_semantics))`, each
      describing a nonzero-probability novel mapping of a subset of `tokens`
      to syntactic categories `candidate_categories` and meanings
      `candidate_semantics`. The log-probability value given is
      `p(meanings, syntaxes | sentence, model)`, under the relevant provided
      meaning likelihoods and the lexicon's distribution over syntactic forms.
    dummy_vars: Mapping from each query token to the dummy logical variable
      that stands in for its meaning in the candidate semantic parses.
  """
  get_arity = (lex.ontology and lex.ontology.get_expr_arity) \
      or get_semantic_arity

  # We will restrict semantic arities based on the observed arities available
  # for each category. Pre-calculate the necessary associations.
  category_sem_arities = lex.category_semantic_arities(
      soft_propagate_roots=True)

  def iter_expressions_for_arity(arity, max_depth=3):
    type_request = ontology.types[("e",) * (arity + 1)]
    return ontology.iter_expressions(max_depth=max_depth,
                                     type_request=type_request)

  def iter_expressions_for_category(cat):
    """
    Generate candidate semantic expressions for a lexical entry with the given
    syntactic category. (Forms type requests based on known associations
    between `cat` and semantic expressions.)
    """
    return itertools.chain.from_iterable(
        iter_expressions_for_arity(arity)
        for arity in category_sem_arities[cat])

  # Shared dummy variables which are included in candidate semantic forms, to
  # be replaced by all candidate lexical expressions and evaluated.
  dummy_vars = {token: l.Variable("F%03i" % i)
                for i, token in enumerate(tokens)}

  category_parse_results = {}
  candidate_queue = None
  for depth in trange(1, len(tokens) + 1, desc="Depths"):
    candidate_queue = UniquePriorityQueue(maxsize=queue_limit)

    token_combs = list(itertools.combinations(tokens, depth))
    for token_comb in tqdm(token_combs, desc="Token combinations"):
      token_syntaxes = [list(candidate_syntaxes[token].support)
                        for token in token_comb]
      for syntax_comb in tqdm(itertools.product(*token_syntaxes),
                              total=np.prod(list(map(len, token_syntaxes))),
                              desc="Syntax combinations"):
        syntax_weights = [candidate_syntaxes[token][cat]
                          for token, cat in zip(token_comb, syntax_comb)]
        if any(weight == 0 for weight in syntax_weights):
          continue

        # Attempt to parse with this joint syntactic assignment, and return
        # the resulting syntactic parses + sentence-level semantic forms, with
        # dummy variables in place of where the candidate expressions will go.
        results = attempt_candidate_parse(lex, token_comb, syntax_comb,
                                          sentence, dummy_vars)
        results = list(results)
        category_parse_results[syntax_comb] = results

        for result, apparent_types in results:
          candidate_exprs = [
              list(ontology.iter_expressions(
                  max_depth=3, type_request=apparent_types[token][0]))
              for token in token_comb
          ]

          n_expr_combs = np.prod(list(map(len, candidate_exprs)))
          for expr_comb in tqdm(itertools.product(*candidate_exprs),
                                total=n_expr_combs, desc="Expressions"):
            # Compute likelihood of this joint syntax--semantics assignment.
            likelihood = 0.0

            # Swap in semantic values for each token.
            sentence_semantics = result.label()[0].semantics()
            for token, token_expr in zip(token_comb, expr_comb):
              dummy_var = dummy_vars[token]
              sentence_semantics = sentence_semantics.replace(dummy_var,
                                                              token_expr)
            sentence_semantics = sentence_semantics.simplify()

            # TODO find a better type-resolution solution. The current one
            # will be super slow -- we can pre-compute most of the relevant
            # types except for the item swapped in.
            try:
              lex.ontology.typecheck(sentence_semantics)
            except l.TypeException:
              continue

            # Compute p(meaning | syntax, sentence, parse).
            logp = sum(likelihood_fn(token_comb, syntax_comb, expr_comb,
                                     sentence_semantics, model)
                       for likelihood_fn in likelihood_fns)
            likelihood += np.exp(logp)

            # Add category priors.
            log_prior = sum(np.log(weight) for weight in syntax_weights)
            if log_prior == -np.inf or likelihood == 0:
              # Zero probability. Skip.
              continue

            joint_score = log_prior + np.log(likelihood)
            data = tuple_unordered([token_comb, syntax_comb, expr_comb])
            new_item = (joint_score, data)

            try:
              candidate_queue.put_nowait(new_item)
            except queue.Full:
              # See if this candidate is better than the worst item in the
              # queue; if so, replace it.
              worst = candidate_queue.get()
              if worst[0] < joint_score:
                replacement = new_item
              else:
                replacement = worst
              candidate_queue.put_nowait(replacement)

    if candidate_queue.qsize() > 0:
      # We have a result. Quit and don't search at higher depth.
      return candidate_queue, dummy_vars

  return candidate_queue, dummy_vars
def iter_expressions(self, max_depth, context, function_weights=None,
                     use_unused_constants=False,
                     unused_constants_whitelist=None,
                     unused_constants_blacklist=None):
  """
  Enumerate all legal expressions.

  Arguments:
    max_depth: Maximum tree depth to traverse.
    function_weights: Override for function weights to determine the order in
      which we consider proposing function application expressions.
    use_unused_constants: If true, always use unused constants.
    unused_constants_whitelist: If not None, a collection of constant names
      that have already been newly used within the current expression and may
      therefore be proposed again here.
    unused_constants_blacklist: If not None, a collection of constant names
      that must not appear in newly proposed expressions.
  """
  if max_depth == 0:
    return
  elif max_depth == 1 and not context.bound_vars:
    # Require some bound variables to generate a valid lexical entry
    # semantics.
    return

  unused_constants_whitelist = frozenset(unused_constants_whitelist or [])
  unused_constants_blacklist = frozenset(unused_constants_blacklist or [])

  for expr_type in self.EXPR_TYPES:
    if expr_type == l.ApplicationExpression:
      # Loop over functions according to their weights.
      fn_weight_key = (lambda fn: function_weights[fn.name]) \
          if function_weights is not None \
          else (lambda fn: fn.weight)
      fns_sorted = sorted(self.ontology.functions_dict.values(),
                          key=fn_weight_key, reverse=True)

      if max_depth > 1:
        for fn in fns_sorted:
          # If there is a present type request, only consider functions with
          # the correct return type.
          # print("\t" * (6 - max_depth), fn.name, fn.return_type,
          #       " // request: ", context.semantic_type, context.bound_vars)
          if context.semantic_type is not None \
              and not fn.return_type.matches(context.semantic_type):
            continue

          # Special case: yield fast event queries without recursion.
          if fn.arity == 1 and fn.arg_types[0] == self.types.EVENT_TYPE:
            yield B.make_application(
                fn.name, (l.ConstantExpression(l.Variable("e")),))
          elif fn.arity == 0:
            # 0-arity functions are represented in the logic as
            # `ConstantExpression`s.
            # print("\t" * (6 - max_depth + 1), "yielding const ", fn.name)
            yield l.ConstantExpression(l.Variable(fn.name))
          else:
            # print("\t" * (6 - max_depth), fn, fn.arg_types)
            all_arg_semantic_types = list(fn.arg_types)

            def product_sub_args(i, ret, blacklist, whitelist):
              # Recursively enumerate argument tuples, threading the
              # whitelist/blacklist of newly used constants through the
              # argument positions.
              if i >= len(all_arg_semantic_types):
                yield ret
                return

              arg_semantic_type = all_arg_semantic_types[i]
              sub_context = context.clone_with_semantic_type(
                  arg_semantic_type)
              results = self.iter_expressions(
                  max_depth=max_depth - 1,
                  context=sub_context,
                  function_weights=function_weights,
                  use_unused_constants=use_unused_constants,
                  unused_constants_whitelist=frozenset(whitelist),
                  unused_constants_blacklist=frozenset(blacklist))

              new_blacklist = blacklist
              for expr in results:
                new_whitelist = whitelist | {c.name for c in expr.constants()}
                for sub_expr in product_sub_args(i + 1, ret + (expr,),
                                                 new_blacklist,
                                                 new_whitelist):
                  yield sub_expr
                  new_blacklist = new_blacklist | {
                      c.name for arg in sub_expr for c in arg.constants()}

            for arg_combs in product_sub_args(0, tuple(),
                                              unused_constants_blacklist,
                                              unused_constants_whitelist):
              candidate = B.make_application(fn.name, arg_combs)
              valid = self.ontology._valid_application_expr(candidate)
              # print("\t" * (6 - max_depth + 1),
              #       "valid %s? %s" % (candidate, valid))
              if valid:
                yield candidate
    elif expr_type == l.LambdaExpression and max_depth > 1:
      if context.semantic_type is None \
          or not isinstance(context.semantic_type, l.ComplexType):
        continue

      for num_args in range(1, len(context.semantic_type.flat)):
        for bound_var_types in itertools.product(
            self.ontology.observed_argument_types, repeat=num_args):
          # TODO typecheck with type request

          bound_vars = list(context.bound_vars)
          subexpr_bound_vars = []
          for new_type in bound_var_types:
            subexpr_bound_vars.append(
                next_bound_var(bound_vars + subexpr_bound_vars, new_type))
          all_bound_vars = tuple(bound_vars + subexpr_bound_vars)

          if context.semantic_type is not None:
            # TODO strong assumption -- assumes that lambda variables are
            # used first
            subexpr_semantic_type_flat = context.semantic_type.flat[num_args:]
            subexpr_semantic_type = self.types[subexpr_semantic_type_flat]
          else:
            subexpr_semantic_type = None

          # Prepare the enumeration context for the lambda body.
          sub_context = context.clone()
          sub_context.bound_vars = all_bound_vars
          sub_context.semantic_type = subexpr_semantic_type

          results = self.iter_expressions(
              max_depth=max_depth - 1,
              context=sub_context,
              function_weights=function_weights,
              use_unused_constants=use_unused_constants,
              unused_constants_whitelist=unused_constants_whitelist,
              unused_constants_blacklist=unused_constants_blacklist)

          for expr in results:
            candidate = expr
            for var in subexpr_bound_vars:
              candidate = l.LambdaExpression(var, candidate)

            valid = self.ontology._valid_lambda_expr(candidate, bound_vars)
            # print("\t" * (6 - max_depth),
            #       "valid lambda %s? %s" % (candidate, valid))
            if valid:
              # Assign variable types before returning.
              extra_types = {bound_var.name: bound_var.type
                             for bound_var in subexpr_bound_vars}

              try:
                # TODO make sure variable names are unique before this happens
                self.ontology.typecheck(candidate, extra_types)
              except l.InconsistentTypeHierarchyException:
                pass
              else:
                yield candidate
    elif expr_type == l.IndividualVariableExpression:
      for bound_var in context.bound_vars:
        if context.semantic_type \
            and not bound_var.type.matches(context.semantic_type):
          continue

        # print("\t" * (6 - max_depth), "var %s" % bound_var)
        yield l.IndividualVariableExpression(bound_var)
    elif expr_type == l.ConstantExpression:
      if use_unused_constants:
        try:
          for constant in self.ontology.constant_system.iter_new_constants(
              semantic_type=context.semantic_type,
              unused_constants_whitelist=unused_constants_whitelist,
              unused_constants_blacklist=unused_constants_blacklist):
            yield l.ConstantExpression(constant)
        except ValueError:
          pass
      else:
        for constant in self.ontology.constants:
          if context.semantic_type is not None \
              and not constant.type.matches(context.semantic_type):
            continue
          yield l.ConstantExpression(constant)
    elif expr_type == l.FunctionVariableExpression:
      # NB we don't support enumerating bound variables with function types
      # right now -- the following only considers yielding fixed functions
      # from the ontology.
      for function in self.ontology.functions:
        # TODO(Jiayuan Mao @ 04/10): check the correctness of the following
        # line. I currently skip nullary functions since they are already
        # handled as constant expressions in the ApplicationExpression branch
        # above.
        if function.arity == 0:
          continue

        # Be a little strict here to avoid excessive enumeration -- only
        # consider emitting functions when the type request specifically
        # demands a function, not e.g. AnyType.
        if context.semantic_type is None \
            or context.semantic_type == self.types.ANY_TYPE \
            or not function.type.matches(context.semantic_type):
          continue

        yield l.FunctionVariableExpression(
            l.Variable(function.name, function.type))
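
# Sketch of a typical call to the enumerator above. Here `enumerator` stands
# for whatever object this method is bound to (it exposes `.ontology` and
# `.types`), and `context` is the enumeration-context object whose
# `bound_vars` / `semantic_type` / `clone*` attributes are used in the body;
# both are assumptions about surrounding code not shown in this excerpt:
#
#   context.semantic_type = enumerator.types["e", "e"]   # request an <e,e> form
#   context.bound_vars = ()
#   for expr in enumerator.iter_expressions(max_depth=3, context=context):
#       ...  # candidate logical expressions compatible with the type request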
def predict_zero_shot(lex, tokens, candidate_syntaxes, sentence, ontology,
                      model, likelihood_fns, scorer, sentence_meta=None,
                      queue_limit=5, iter_expressions_args=None):
  """
  Make zero-shot predictions of the posterior `p(syntax, meaning | sentence)`
  for each of `tokens`.

  Args:
    lex: Lexicon of the learner.
    tokens: Query wordforms (strings) for which novel lexical entries should
      be predicted.
    candidate_syntaxes: Mapping from each query token to a weighted
      distribution over candidate syntactic categories.
    sentence: The observed sentence, as a sequence of token strings.
    ontology: Ontology used to enumerate candidate meaning expressions.
    model: Scene model passed through to each likelihood function.
    likelihood_fns: Collection of likelihood functions
      `p(meanings | syntaxes, sentence, model)` used to score candidate
      meaning--syntax settings for a subset of `tokens`. Each function should
      accept arguments `(tokens, candidate_categories, candidate_meanings,
      candidate_semantic_parse, model)`, where `tokens` are assigned specific
      categories given in `candidate_categories` and specific meanings given
      in `candidate_meanings`, yielding a single semantic analysis of the
      sentence `candidate_semantic_parse`. The function should return a
      log-likelihood `p(candidate_meanings | candidate_syntaxes, sentence,
      model)`.

  Returns:
    candidate_queue: A ranked queue of candidates of the form
      `(logprob, (tokens, candidate_categories, candidate_semantics))`, each
      describing a nonzero-probability novel mapping of a subset of `tokens`
      to syntactic categories `candidate_categories` and meanings
      `candidate_semantics`. The log-probability value given is
      `p(meanings, syntaxes | sentence, model)`, under the relevant provided
      meaning likelihoods and the lexicon's distribution over syntactic forms.
    dummy_vars: Mapping from each query token to the dummy logical variable
      that stands in for its meaning in the candidate semantic parses.
  """
  get_arity = (lex.ontology and lex.ontology.get_expr_arity) \
      or get_semantic_arity
  iter_expressions_args = iter_expressions_args or {}

  # We will restrict semantic arities based on the observed arities available
  # for each category. Pre-calculate the necessary associations.
  category_sem_arities = lex.category_semantic_arities()

  def iter_expressions_for_arity(arity, max_depth=3, blacklist=None, **kwargs):
    semantic_type = ontology.types[("e",) * (arity + 1)]

    passed_kwargs = dict()
    # First include global iter_expressions kwargs.
    passed_kwargs.update(iter_expressions_args)
    # Now set instance-specific arguments.
    assert 'semantic_type' not in passed_kwargs
    passed_kwargs['semantic_type'] = semantic_type
    assert 'unused_constants_blacklist' not in passed_kwargs
    passed_kwargs['unused_constants_blacklist'] = blacklist
    passed_kwargs.setdefault('max_depth', max_depth)
    for key, value in kwargs.items():
      passed_kwargs.setdefault(key, value)

    return ontology.iter_expressions(**passed_kwargs)

  def iter_expressions_for_category(cat, blacklist=None):
    """
    Generate candidate semantic expressions for a lexical entry with the given
    syntactic category. (Forms type requests based on known associations
    between `cat` and semantic expressions.)
    """
    return itertools.chain.from_iterable(
        iter_expressions_for_arity(arity, blacklist=blacklist,
                                   syntactic_type=cat)
        for arity in category_sem_arities[cat])

  def product_candidate_exprs(syntax_comb):
    # NB(Jiayuan Mao @ 04/11): accelerate the iteration.
    # TODO(Jiayuan Mao @ 04/11): do we need this? maybe the cache of
    # iter_expressions can automatically handle this.
    if len(syntax_comb) == 1 \
        or not iter_expressions_args.get('use_unused_constants', False):
      candidate_exprs = tuple(list(iter_expressions_for_category(cat))
                              for cat in syntax_comb)
      return list(itertools.product(*candidate_exprs))

    # When sharing the "unused constants" bookkeeping across tokens, blacklist
    # constants already proposed for earlier categories in the combination.
    candidate_exprs = list()
    blacklist = set()
    for cat in syntax_comb:
      this_candidate_exprs = list(
          iter_expressions_for_category(cat, frozenset(blacklist)))
      blacklist |= {c.name for expr in this_candidate_exprs
                    for c in expr.constants()}
      candidate_exprs.append(this_candidate_exprs)
    return list(itertools.product(*candidate_exprs))

  # Shared dummy variables which are included in candidate semantic forms, to
  # be replaced by all candidate lexical expressions and evaluated.
  dummy_vars = {token: l.Variable("F%03i" % i)
                for i, token in enumerate(tokens)}

  category_parse_results = {}
  candidate_queue = None
  for depth in trange(1, len(tokens) + 1, desc="Depths"):
    candidate_queue = UniquePriorityQueue(maxsize=queue_limit)

    token_combs = list(itertools.combinations(tokens, depth))
    for token_comb in token_combs:
      # TODO(Jiayuan Mao @ 04/10): if there are multiple words to be induced
      # at the same time, there will be a bug for use_unused_constants.
      token_syntaxes = [list(candidate_syntaxes[token].support)
                        for token in token_comb]
      for syntax_comb in tqdm(itertools.product(*token_syntaxes),
                              total=np.prod(list(map(len, token_syntaxes))),
                              desc="Syntax combinations"):
        syntax_weights = [candidate_syntaxes[token][cat]
                          for token, cat in zip(token_comb, syntax_comb)]
        if any(weight == 0 for weight in syntax_weights):
          continue

        # Attempt to parse with this joint syntactic assignment, and return
        # the resulting syntactic parses + sentence-level semantic forms, with
        # dummy variables in place of where the candidate expressions will go.
        results = attempt_candidate_parse(lex, token_comb, syntax_comb,
                                          sentence, scorer, dummy_vars,
                                          sentence_meta=sentence_meta)
        category_parse_results[syntax_comb] = results

        # Now enumerate semantic forms for the tokens in this combination.
        all_expr_combs = product_candidate_exprs(syntax_comb)
        for expr_comb in tqdm(all_expr_combs, desc="Expressions"):
          # Compute likelihood of this joint syntax--semantics assignment.
          # TODO(Jiayuan Mao @ 04/08): += logp? or += p?
          # Probably we can remove the following line?
          likelihood = 0.0

          for result in results:
            # Swap in semantic values for each token.
            sentence_semantics = result.label()[0].semantics()
            for token, token_expr in zip(token_comb, expr_comb):
              dummy_var = dummy_vars[token]
              sentence_semantics = sentence_semantics.replace(dummy_var,
                                                              token_expr)
            sentence_semantics = sentence_semantics.simplify()

            try:
              lex.ontology.typecheck(sentence_semantics)
            except l.TypeException:
              continue

            # Compute p(meaning | syntax, sentence, parse).
            logp = sum(likelihood_fn(token_comb, syntax_comb, expr_comb,
                                     sentence_semantics, model)
                       for likelihood_fn in likelihood_fns)
            likelihood += np.exp(logp)

            # Add category priors.
            log_prior = sum(T.log(weight) for weight in syntax_weights)

            joint_score = log_prior + logp
            if joint_score == -np.inf:
              # Zero probability. Skip.
              continue

            data = tuple_unordered([token_comb, syntax_comb, expr_comb])
            new_item = (joint_score, data)
            try:
              candidate_queue.put_nowait(new_item)
            except queue.Full:
              # See if this candidate is better than the worst item in the
              # queue; if so, replace it.
              worst = candidate_queue.get()
              if worst[0] < joint_score:
                replacement = new_item
              else:
                replacement = worst
              candidate_queue.put_nowait(replacement)

    if candidate_queue.qsize() > 0:
      # We have a result. Quit and don't search at higher depth.
      return candidate_queue, dummy_vars

  return candidate_queue, dummy_vars
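
# Sketch of how this function is typically driven. The learner-side objects
# (`lex`, `candidate_syntaxes`, `model`, `likelihood_fn`, `scorer`) are
# assumed to be constructed elsewhere and are not part of this excerpt:
#
#   candidate_queue, dummy_vars = predict_zero_shot(
#       lex, ["sends"], candidate_syntaxes, "John sends Mark it".split(),
#       lex.ontology, model, [likelihood_fn], scorer,
#       sentence_meta={"frame_str": "..."}, queue_limit=5)
#   while not candidate_queue.empty():
#       # Items come back lowest-scoring first, since it is a priority queue.
#       logprob, (toks, cats, meanings) = candidate_queue.get()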