def combine(self, function, argument):
    """Combine ``function`` with ``argument`` when ``self.can_combine`` allows it.

    The resulting category is ``function``'s result-of-result taken over
    ``argument``'s argument, in ``argument``'s direction (the category shape
    of CCG substitution).

    Semantics: with ``fsem = \\x. fbody`` and ``asem = g``, the result is
    ``\\x. fbody(g(x))``. If either input has no semantics, the yielded
    semantics is ``None``, matching the other combinators in this module
    instead of raising ``AttributeError`` on ``None``.

    Yields:
        At most one ``(category, semantics)`` pair.
    """
    if not self.can_combine(function, argument):
        return

    categ = FunctionalCategory(
        function.categ().res().res(),
        argument.categ().arg(),
        argument.categ().dir(),
    )

    # TODO type-inference
    fsem, asem = function.semantics(), argument.semantics()
    semantics = None
    if fsem is not None and asem is not None:
        # Reuse f's bound variable x: build g(x), then feed it to f's body.
        new_arg = l.ApplicationExpression(
            asem, l.VariableExpression(fsem.variable)).simplify()
        new_term = l.ApplicationExpression(fsem.term, new_arg).simplify()
        semantics = l.LambdaExpression(fsem.variable, new_term)

    yield categ, semantics
def combine(self, function, argument):
    """Compose two function categories.

    Both categories must be functions whose directions allow composition, and
    ``function``'s argument must unify with ``argument``'s result. The yielded
    category is ``function``'s result over ``argument``'s argument (both under
    the unifier), in ``argument``'s direction.

    Semantics: with ``asem = \\x. gbody`` the result is ``\\x. fsem(gbody)``,
    or ``None`` when either side lacks semantics.

    Yields:
        At most one ``(category, semantics)`` pair.
    """
    fcat = function.categ()
    acat = argument.categ()

    if not fcat.is_function() or not acat.is_function():
        return
    if not (fcat.dir().can_compose() and acat.dir().can_compose()):
        return

    unifier = fcat.arg().can_unify(acat.res())
    if unifier is None:
        return

    result_categ = FunctionalCategory(
        fcat.res().substitute(unifier),
        acat.arg().substitute(unifier),
        acat.dir())

    fsem = function.semantics()
    asem = argument.semantics()
    result_sem = None
    if fsem is not None and asem is not None:
        body = l.ApplicationExpression(fsem, asem.term).simplify()
        result_sem = l.LambdaExpression(asem.variable, body)

    yield result_categ, result_sem
def combine(self, function, arg):
    """Type-raise ``function`` so that it can consume ``arg``.

    ``function`` must have a primitive category. ``arg`` must have a function
    category whose result is itself a function. The innermost function of
    ``arg``'s category (via ``innermostFunction``) must take an argument that
    unifies with ``function``'s category. With ``xcat`` being that inner
    function's result under the unifier, the yielded category has result
    ``xcat`` and argument ``(xcat over function's category, inner
    direction)``, in the opposite direction. This assumes the constructor
    order is ``FunctionalCategory(res, arg, dir)``.

    Semantics: a copy of ``function``'s semantics whose innermost
    (non-lambda) body ``b`` is replaced by ``F(b)``, all abstracted as
    ``\\F. ...``. ``F`` is renamed if it would clash with a free variable of
    ``b``. The semantics is ``None`` when ``function`` has none.

    Yields:
        At most one ``(category, semantics)`` pair.
    """
    fcat = function.categ()
    acat = arg.categ()
    if not (fcat.is_primitive() and acat.is_function()
            and acat.res().is_function()):
        return

    # Type-raising matches only the innermost application.
    inner = innermostFunction(acat)

    subs = fcat.can_unify(inner.arg())
    if subs is None:
        return

    xcat = inner.res().substitute(subs)
    categ = FunctionalCategory(
        xcat, FunctionalCategory(xcat, fcat, inner.dir()), -(inner.dir()))

    # compute semantics
    semantics = None
    fsem = function.semantics()
    if fsem is not None:
        # Work on a copy so the input's semantics are never mutated.
        root = deepcopy(fsem)

        # Descend through leading lambdas to the innermost body.
        parent = None
        core = root
        while isinstance(core, l.LambdaExpression):
            parent = core
            core = core.term

        # Pick a function variable that does not clash with the body's free
        # variables (which include the enclosing lambdas' variables).
        var = l.Variable("F")
        while var in core.free():
            var = l.unique_variable(pattern=var)
        core = l.ApplicationExpression(l.FunctionVariableExpression(var), core)

        if parent is not None:
            # Splice F(body) back under the innermost lambda; the root of the
            # copy still holds the full expression. (Previously the root was
            # dropped and the result was `\F. None`.)
            parent.term = core
        else:
            root = core

        semantics = l.LambdaExpression(var, root)

    yield categ, semantics
def iter_expressions(self, max_depth, context, function_weights=None,
                     use_unused_constants=False,
                     unused_constants_whitelist=None,
                     unused_constants_blacklist=None):
    """
    Enumerate all legal expressions.

    Arguments:
        max_depth: Maximum tree depth to traverse. Nothing is yielded at
            depth 0. Nothing is yielded at depth 1 unless ``context`` has
            bound variables.
        context: Enumeration context. ``context.semantic_type`` (or None) is
            the requested type of yielded expressions. ``context.bound_vars``
            are the variables that may appear in them.
        function_weights: Override for function weights (a mapping from
            function name to weight) to determine the order in which we
            consider proposing function application expressions. When None,
            each function's own ``weight`` is used.
        use_unused_constants: If true, propose constants from
            ``self.ontology.constant_system.iter_new_constants`` instead of
            the existing ``self.ontology.constants``.
        unused_constants_whitelist: If not None, a collection of constant
            names forwarded to ``iter_new_constants``. While enumerating the
            arguments of an application, it grows with the constants used by
            earlier arguments of that application.
        unused_constants_blacklist: If not None, a collection of constant
            names forwarded to ``iter_new_constants``. While enumerating the
            arguments of an application, it grows with the constants of
            previously yielded argument tuples.

    Yields:
        Candidate logical expressions, grouped in the order of
        ``self.EXPR_TYPES``.
    """
    if max_depth == 0:
        return
    elif max_depth == 1 and not context.bound_vars:
        # require some bound variables to generate a valid lexical entry
        # semantics
        return

    unused_constants_whitelist = frozenset(unused_constants_whitelist or [])
    unused_constants_blacklist = frozenset(unused_constants_blacklist or [])

    def product_sub_args(arg_types, i, ret, blacklist, whitelist):
        # Yield every tuple of argument expressions for arg_types[i:], each
        # appended to the partial argument tuple `ret`.
        if i >= len(arg_types):
            yield ret
            return

        # Request candidates of the type this argument position requires
        # (previously the parent context, i.e. the function's return-type
        # request, was passed here by mistake).
        sub_context = context.clone_with_semantic_type(arg_types[i])
        results = self.iter_expressions(
            max_depth=max_depth - 1,
            context=sub_context,
            function_weights=function_weights,
            use_unused_constants=use_unused_constants,
            unused_constants_whitelist=frozenset(whitelist),
            unused_constants_blacklist=frozenset(blacklist))

        new_blacklist = blacklist
        for expr in results:
            # Constants used by this argument are whitelisted for the
            # remaining argument positions.
            new_whitelist = whitelist | {c.name for c in expr.constants()}
            for sub_expr in product_sub_args(arg_types, i + 1, ret + (expr,),
                                             new_blacklist, new_whitelist):
                yield sub_expr
                # Blacklist constants of every yielded argument tuple; this
                # affects the later positions explored for subsequent
                # candidates at this position.
                new_blacklist = new_blacklist | {
                    c.name for sub_arg in sub_expr for c in sub_arg.constants()}

    for expr_type in self.EXPR_TYPES:
        if expr_type == l.ApplicationExpression:
            if max_depth <= 1:
                # Applications are only proposed when depth remains below
                # this node.
                continue

            # Loop over functions according to their weights.
            fn_weight_key = (lambda fn: function_weights[fn.name]) \
                if function_weights is not None else (lambda fn: fn.weight)
            fns_sorted = sorted(self.ontology.functions_dict.values(),
                                key=fn_weight_key, reverse=True)

            for fn in fns_sorted:
                # If there is a present type request, only consider functions
                # with the correct return type.
                if context.semantic_type is not None \
                        and not fn.return_type.matches(context.semantic_type):
                    continue

                if fn.arity == 1 and fn.arg_types[0] == self.types.EVENT_TYPE:
                    # Special case: yield fast event queries without recursion.
                    yield B.make_application(
                        fn.name, (l.ConstantExpression(l.Variable("e")),))
                elif fn.arity == 0:
                    # 0-arity functions are represented in the logic as
                    # `ConstantExpression`s.
                    yield l.ConstantExpression(l.Variable(fn.name))
                else:
                    for arg_combs in product_sub_args(
                            list(fn.arg_types), 0, tuple(),
                            unused_constants_blacklist,
                            unused_constants_whitelist):
                        candidate = B.make_application(fn.name, arg_combs)
                        if self.ontology._valid_application_expr(candidate):
                            yield candidate

        elif expr_type == l.LambdaExpression and max_depth > 1:
            if context.semantic_type is None or not isinstance(
                    context.semantic_type, l.ComplexType):
                continue

            flat_type = context.semantic_type.flat
            for num_args in range(1, len(flat_type)):
                for bound_var_types in itertools.product(
                        self.ontology.observed_argument_types, repeat=num_args):
                    # TODO typecheck with type request
                    bound_vars = list(context.bound_vars)
                    subexpr_bound_vars = []
                    for new_type in bound_var_types:
                        subexpr_bound_vars.append(next_bound_var(
                            bound_vars + subexpr_bound_vars, new_type))
                    all_bound_vars = tuple(bound_vars + subexpr_bound_vars)

                    # TODO strong assumption -- assumes that lambda variables
                    # are used first
                    subexpr_semantic_type = self.types[flat_type[num_args:]]

                    # Prepare enumeration context
                    sub_context = context.clone()
                    sub_context.bound_vars = all_bound_vars
                    sub_context.semantic_type = subexpr_semantic_type

                    results = self.iter_expressions(
                        max_depth=max_depth - 1,
                        context=sub_context,
                        function_weights=function_weights,
                        use_unused_constants=use_unused_constants,
                        unused_constants_whitelist=unused_constants_whitelist,
                        unused_constants_blacklist=unused_constants_blacklist)
                    for expr in results:
                        # Abstract the new variables; the first one ends up
                        # innermost.
                        candidate = expr
                        for var in subexpr_bound_vars:
                            candidate = l.LambdaExpression(var, candidate)
                        if not self.ontology._valid_lambda_expr(
                                candidate, bound_vars):
                            continue

                        # Assign variable types before returning.
                        extra_types = {bound_var.name: bound_var.type
                                       for bound_var in subexpr_bound_vars}
                        try:
                            # TODO make sure variable names are unique before
                            # this happens
                            self.ontology.typecheck(candidate, extra_types)
                        except l.InconsistentTypeHierarchyException:
                            # Candidates that fail type checking are dropped.
                            continue
                        yield candidate

        elif expr_type == l.IndividualVariableExpression:
            for bound_var in context.bound_vars:
                if context.semantic_type and not bound_var.type.matches(
                        context.semantic_type):
                    continue
                yield l.IndividualVariableExpression(bound_var)

        elif expr_type == l.ConstantExpression:
            if use_unused_constants:
                try:
                    for constant in self.ontology.constant_system.iter_new_constants(
                            semantic_type=context.semantic_type,
                            unused_constants_whitelist=unused_constants_whitelist,
                            unused_constants_blacklist=unused_constants_blacklist):
                        yield l.ConstantExpression(constant)
                except ValueError:
                    # NOTE(review): iter_new_constants presumably raises
                    # ValueError when it cannot propose a constant; this is
                    # treated as "no (more) candidates". Confirm its contract.
                    pass
            else:
                for constant in self.ontology.constants:
                    if context.semantic_type is not None \
                            and not constant.type.matches(context.semantic_type):
                        continue
                    yield l.ConstantExpression(constant)

        elif expr_type == l.FunctionVariableExpression:
            # NB we don't support enumerating bound variables with function
            # types right now -- the following only considers yielding fixed
            # functions from the ontology.
            for function in self.ontology.functions:
                # TODO(Jiayuan Mao @ 04/10): check the correctness of the
                # following line. Nullary functions are skipped because the
                # ApplicationExpression branch of this method already yields
                # 0-arity functions as constants (when max_depth > 1).
                if function.arity == 0:
                    continue
                # Be a little strict here to avoid excessive enumeration --
                # only consider emitting functions when the type request
                # specifically demands a function, not e.g. AnyType
                if context.semantic_type is None \
                        or context.semantic_type == self.types.ANY_TYPE \
                        or not function.type.matches(context.semantic_type):
                    continue
                yield l.FunctionVariableExpression(
                    l.Variable(function.name, function.type))