Code Example #1
    def forward(self, parse, sentence_meta=None):
        if sentence_meta is None or sentence_meta.get("frame_str",
                                                      None) is None:
            raise ValueError(
                "FrameSemanticsScorer requires a sentence_meta key frame_str")

        frame = sentence_meta["frame_str"]
        try:
            frame_idx = self.frame_to_idx[frame]
        except KeyError:
            raise ValueError("Unknown frame string %s" % frame)

        # Reuse the index looked up above; normalize the frame's row into
        # log-probabilities over predicates. (log_softmax requires an
        # explicit dim in current PyTorch.)
        ret = self.frame_dist(frame_idx)
        predicate_logps = F.log_softmax(ret, dim=-1)

        score = T.zeros(())
        try:
            root_verb = next(tok for _, tok in parse.pos()
                             if str(tok.categ()) in self.root_types)
        except StopIteration:
            # No token with one of the expected root syntactic types.
            return score

        for predicate in root_verb.semantics().predicates():
            score += predicate_logps[self.predicate_to_idx[predicate]]

        for predicate in root_verb.semantics().constants():
            predicate = l.Variable(predicate.name)
            score += predicate_logps[self.predicate_to_idx[predicate]]

        return score
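
The scoring above boils down to: look up the frame's row in an embedding table, normalize it into log-probabilities over predicates with a log-softmax, and sum the log-probabilities of the predicates mentioned by the root verb. A minimal standalone sketch of that computation with toy sizes (not the pyccg API):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy setup: 2 frames, 4 predicates.
frame_dist = nn.Embedding(2, 4)
nn.init.zeros_(frame_dist.weight)

frame_idx = torch.tensor(0)
predicate_logps = F.log_softmax(frame_dist(frame_idx), dim=-1)

# A parse whose root verb mentions predicates 1 and 3 scores the sum of
# their log-probabilities under the frame's distribution.
score = predicate_logps[1] + predicate_logps[3]
print(score)  # 2 * log(1/4) ~= -2.77 under the uniform zero init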
Code Example #2
    def __init__(self, lexicon, frames, root_types=(r"(S\N)", r"((S\N)/N)")):
        """
    Args:
      lexicon:
      frames: Collection of all possible frame strings
      root_types: CCG syntactic type strings of lexical entries for which we
        are collecting frames
    """
        super().__init__(lexicon)

        self.frames = frames
        self.frame_to_idx = {
            frame: T.tensor(idx, requires_grad=False)
            for idx, frame in enumerate(sorted(self.frames))
        }

        self.root_types = set(root_types)
        self.gradients_disabled = False

        ontology = self._lexicon.ontology
        self.predicates = [
            l.Variable(val.name)
            for val in ontology.functions + ontology.constants
        ]
        self.predicate_to_idx = {
            pred: idx
            for idx, pred in enumerate(sorted(self.predicates))
        }

        # Represent unnormalized frame distributions as an embedding layer
        self.frame_dist = nn.Embedding(len(self.frames), len(self.predicates))
        nn.init.zeros_(self.frame_dist.weight)
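
One design note worth making explicit: each row of `frame_dist.weight` stores one frame's unnormalized log-distribution over predicates, so the zero initialization makes every frame start out uniform. A quick standalone check with toy sizes:

import torch.nn as nn
import torch.nn.functional as F

dist = nn.Embedding(3, 5)  # 3 frames, 5 predicates
nn.init.zeros_(dist.weight)
# Every row normalizes to the uniform distribution (all entries 0.2).
print(F.softmax(dist.weight.detach(), dim=-1))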
Code Example #3
File: test_lexicon.py  Project: thuqinyj16/pyccg
def test_attempt_candidate_parse():
    """
  Find parse candidates even when the parse requires composition.
  """
    lex = Lexicon.fromstring(r"""
  :- S, N

  gives => S\N/N/N {\o x y.give(x, y, o)}
  John => N {\x.John(x)}
  Mark => N {\x.Mark(x)}
  it => N {\x.T}
  """,
                             include_semantics=True)
    # TODO this doesn't actually require composition -- find one that does

    cand_category = lex.parse_category(r"S\N/N/N")
    cand_expressions = [l.Expression.fromstring(r"\o x y.give(x,y,o)")]
    dummy_vars = {"sends": l.Variable("F000")}
    results = attempt_candidate_parse(lex, ["sends"], [cand_category],
                                      "John sends Mark it".split(), dummy_vars)

    ok_(len(list(results)) > 0)
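
pyccg builds on NLTK's CCG machinery, so the grammar in this test can be sanity-checked with NLTK alone. A sketch, assuming `nltk` is installed (it parses with the known word "gives" rather than the novel "sends"):

from nltk.ccg import chart, lexicon

lex = lexicon.fromstring(r"""
:- S, N
gives => ((S\N)/N)/N {\o x y.give(x, y, o)}
John => N {\x.John(x)}
Mark => N {\x.Mark(x)}
it => N {\x.T}
""", include_semantics=True)

parser = chart.CCGChartParser(lex, chart.DefaultRuleSet)
for parse in parser.parse("John gives Mark it".split()):
    chart.printCCGDerivation(parse)
    break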
Code Example #4
File: combinator.py  Project: cogsci2020verb/pyccg
    def combine(self, function, arg):
        if not (function.categ().is_primitive() and arg.categ().is_function()
                and arg.categ().res().is_function()):
            return

        # Type-raising matches only the innermost application.
        arg = innermostFunction(arg.categ())

        subs = function.categ().can_unify(arg.arg())
        if subs is not None:
            xcat = arg.res().substitute(subs)
            categ = FunctionalCategory(
                xcat, FunctionalCategory(xcat, function.categ(), arg.dir()),
                -(arg.dir()))

            # compute semantics
            semantics = None
            if function.semantics() is not None:
                # Keep a handle on the root of the copied expression in
                # `semantics` while `core` descends through any lambda prefix
                # to the innermost body.
                semantics = deepcopy(function.semantics())
                core = semantics
                parent = None
                while isinstance(core, l.LambdaExpression):
                    parent = core
                    core = core.term

                # Choose a fresh function variable not free in the body.
                var = l.Variable("F")
                while var in core.free():
                    var = l.unique_variable(pattern=var)
                core = l.ApplicationExpression(
                    l.FunctionVariableExpression(var), core)

                # If there was a lambda prefix, splicing `core` back in via
                # `parent` already updated the copied root; only a bare,
                # lambda-free body needs replacing outright.
                if parent is not None:
                    parent.term = core
                else:
                    semantics = core

                semantics = l.LambdaExpression(var, semantics)

            yield categ, semantics
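
The semantics computed here follows the standard type-raising recipe `a => \F.F(a)`: descend through any lambda prefix and wrap the innermost body in an application of a fresh function variable. NLTK ships the same computation, which gives a quick way to check the expected output (assuming `nltk` is installed):

from nltk.sem.logic import Expression
from nltk.ccg.logic import compute_type_raised_semantics

sem = Expression.fromstring(r"\x y.give(x,y,o)")
print(compute_type_raised_semantics(sem))
# -> \F x y.F(give(x,y,o))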
Code Example #5
File: lexicon.py  Project: liqing-ustc/dreamcoder
def predict_zero_shot(lex,
                      tokens,
                      candidate_syntaxes,
                      sentence,
                      ontology,
                      model,
                      likelihood_fns,
                      queue_limit=5):
    """
  Make zero-shot predictions of the posterior `p(syntax, meaning | sentence)`
  for each of `tokens`.

  Args:
    lex: Current lexicon.
    tokens: Tokens for which novel lexical entries should be proposed.
    candidate_syntaxes: For each token, a distribution over candidate
      syntactic categories.
    sentence: Sentence (sequence of tokens) being parsed.
    ontology: Ontology used to enumerate candidate meaning expressions.
    model: Model passed through to the likelihood functions.
    likelihood_fns: Collection of likelihood functions
      `p(meanings | syntaxes, sentence, model)` used to score candidate
      meaning--syntax settings for a subset of `tokens`.  Each function should
      accept arguments `(tokens, candidate_categories, candidate_meanings,
      candidate_semantic_parse, model)`, where `tokens` are assigned specific
      categories given in `candidate_categories` and specific meanings given in
      `candidate_meanings`, yielding a single semantic analysis of the sentence
      `candidate_semantic_parse`. The function should return a log-likelihood
      `p(candidate_meanings | candidate_syntaxes, sentence, model)`.

  Returns:
    queues: A dictionary mapping each query token to a ranked sequence of
      candidates of the form
      `(logprob, (tokens, candidate_categories, candidate_semantics))`,
      describing a nonzero-probability novel mapping of a subset of `tokens` to
      syntactic categories `candidate_categories` and meanings
      `candidate_semantics`. The log-probability value given is
      `p(meanings, syntaxes | sentence, model)`, under the relevant provided
      meaning likelihoods and the lexicon's distribution over syntactic forms.
    dummy_vars: TODO
  """

    get_arity = (lex.ontology and lex.ontology.get_expr_arity) \
        or get_semantic_arity

    # We will restrict semantic arities based on the observed arities available
    # for each category. Pre-calculate the necessary associations.
    category_sem_arities = lex.category_semantic_arities(
        soft_propagate_roots=True)

    def iter_expressions_for_arity(arity, max_depth=3):
        type_request = ontology.types[("e", ) * (arity + 1)]
        return ontology.iter_expressions(max_depth=max_depth,
                                         type_request=type_request)

    def iter_expressions_for_category(cat):
        """
    Generate candidate semantic expressions for a lexical entry with the given
    syntactic category. (Forms type requests based on known associations
    between `cat` and semantic expressions.)
    """
        return itertools.chain.from_iterable(
            iter_expressions_for_arity(arity)
            for arity in category_sem_arities[cat])

    # Shared dummy variables which are included in candidate semantic forms,
    # to be replaced by candidate lexical expressions during evaluation.
    dummy_vars = {
        token: l.Variable("F%03i" % i)
        for i, token in enumerate(tokens)
    }

    category_parse_results = {}
    candidate_queue = None
    for depth in trange(1, len(tokens) + 1, desc="Depths"):
        candidate_queue = UniquePriorityQueue(maxsize=queue_limit)

        token_combs = list(itertools.combinations(tokens, depth))
        for token_comb in tqdm(token_combs, desc="Token combinations"):
            token_syntaxes = [
                list(candidate_syntaxes[token].support) for token in token_comb
            ]
            for syntax_comb in tqdm(itertools.product(*token_syntaxes),
                                    total=np.prod(
                                        list(map(len, token_syntaxes))),
                                    desc="Syntax combinations"):
                syntax_weights = [
                    candidate_syntaxes[token][cat]
                    for token, cat in zip(token_comb, syntax_comb)
                ]
                if any(weight == 0 for weight in syntax_weights):
                    continue

                # Attempt to parse with this joint syntactic assignment, and
                # collect the resulting syntactic parses + sentence-level
                # semantic forms, with dummy variables standing in where the
                # candidate expressions will go.
                results = attempt_candidate_parse(lex, token_comb, syntax_comb,
                                                  sentence, dummy_vars)
                results = list(results)
                category_parse_results[syntax_comb] = results

                for result, apparent_types in results:
                    candidate_exprs = [
                        list(
                            ontology.iter_expressions(
                                max_depth=3,
                                type_request=apparent_types[token][0]))
                        for token in token_comb
                    ]

                    n_expr_combs = np.prod(list(map(len, candidate_exprs)))
                    for expr_comb in tqdm(itertools.product(*candidate_exprs),
                                          total=n_expr_combs,
                                          desc="Expressions"):
                        # Compute likelihood of this joint syntax--semantics assignment.
                        likelihood = 0.0

                        # Swap in semantic values for each token.
                        sentence_semantics = result.label()[0].semantics()
                        for token, token_expr in zip(token_comb, expr_comb):
                            dummy_var = dummy_vars[token]
                            sentence_semantics = sentence_semantics.replace(
                                dummy_var, token_expr)
                        sentence_semantics = sentence_semantics.simplify()

                        # TODO find a better type-resolution solution. The current one will
                        # be super slow -- we can pre-compute most of the relevant types
                        # except for the item swapped in.
                        try:
                            lex.ontology.typecheck(sentence_semantics)
                        except l.TypeException:
                            continue

                        # Compute p(meaning | syntax, sentence, parse)
                        logp = sum(
                            likelihood_fn(token_comb, syntax_comb, expr_comb,
                                          sentence_semantics, model)
                            for likelihood_fn in likelihood_fns)
                        likelihood += np.exp(logp)

                        # Add category priors.
                        log_prior = sum(
                            np.log(weight) for weight in syntax_weights)
                        if log_prior == -np.inf or likelihood == 0:
                            # Zero probability. Skip.
                            continue

                        joint_score = log_prior + np.log(likelihood)

                        data = tuple_unordered(
                            [token_comb, syntax_comb, expr_comb])
                        new_item = (joint_score, data)
                        try:
                            candidate_queue.put_nowait(new_item)
                        except queue.Full:
                            # See if this candidate is better than the worst item.
                            worst = candidate_queue.get()
                            if worst[0] < joint_score:
                                replacement = new_item
                            else:
                                replacement = worst

                            candidate_queue.put_nowait(replacement)

        if candidate_queue.qsize() > 0:
            # We have a result. Quit and don't search at higher depth.
            return candidate_queue, dummy_vars

    return candidate_queue, dummy_vars
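
The `queue.Full` handling at the end is a bounded top-k buffer: `queue.PriorityQueue` pops its smallest (lowest-scoring) item first, so on overflow the new candidate competes against the current worst. A standalone sketch of the idiom:

import queue

def put_bounded(q, item):
    # On overflow, pop the current worst (smallest score) and keep the
    # better of the two, so the queue always holds the best-k seen so far.
    try:
        q.put_nowait(item)
    except queue.Full:
        worst = q.get()
        q.put_nowait(max(worst, item))

q = queue.PriorityQueue(maxsize=2)
for score in [1.0, 3.0, 2.0, 5.0]:
    put_bounded(q, (score, "item-%g" % score))
print(sorted(q.queue))  # [(3.0, 'item-3'), (5.0, 'item-5')]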
Code Example #6
File: iterators.py  Project: cogsci2020verb/pyccg
    def iter_expressions(self,
                         max_depth,
                         context,
                         function_weights=None,
                         use_unused_constants=False,
                         unused_constants_whitelist=None,
                         unused_constants_blacklist=None):
        """
    Enumerate all legal expressions.

    Arguments:
      max_depth: Maximum tree depth to traverse.
      context: Enumeration context carrying the currently bound variables and
        the requested semantic type.
      function_weights: Override for function weights to determine the order in
        which we consider proposing function application expressions.
      use_unused_constants: If true, propose only constants that have not yet
        been used (via the ontology's constant system).
      unused_constants_whitelist: If not None, a set of constants (by name)
        that may still be proposed even though they are already used -- e.g.
        constants newly introduced elsewhere in the current expression.
      unused_constants_blacklist: If not None, a set of constants (by name)
        that must not be proposed for the current expression.
    """
        if max_depth == 0:
            return
        elif max_depth == 1 and not context.bound_vars:
            # require some bound variables to generate a valid lexical entry
            # semantics
            return

        unused_constants_whitelist = frozenset(unused_constants_whitelist
                                               or [])
        unused_constants_blacklist = frozenset(unused_constants_blacklist
                                               or [])

        for expr_type in self.EXPR_TYPES:
            if expr_type == l.ApplicationExpression:
                # Loop over functions according to their weights.
                fn_weight_key = (lambda fn: function_weights[fn.name]) if function_weights is not None \
                                else (lambda fn: fn.weight)
                fns_sorted = sorted(self.ontology.functions_dict.values(),
                                    key=fn_weight_key,
                                    reverse=True)

                if max_depth > 1:
                    for fn in fns_sorted:
                        # If there is a present type request, only consider functions with
                        # the correct return type.
                        # print("\t" * (6 - max_depth), fn.name, fn.return_type, " // request: ", context.semantic_type, context.bound_vars)
                        if context.semantic_type is not None and not fn.return_type.matches(
                                context.semantic_type):
                            continue

                        # Special case: yield fast event queries without recursion.
                        if fn.arity == 1 and fn.arg_types[
                                0] == self.types.EVENT_TYPE:
                            yield B.make_application(
                                fn.name,
                                (l.ConstantExpression(l.Variable("e")), ))
                        elif fn.arity == 0:
                            # 0-arity functions are represented in the logic as
                            # `ConstantExpression`s.
                            # print("\t" * (6 - max_depth + 1), "yielding const ", fn.name)
                            yield l.ConstantExpression(l.Variable(fn.name))
                        else:
                            # print("\t" * (6 - max_depth), fn, fn.arg_types)
                            all_arg_semantic_types = list(fn.arg_types)

                            def product_sub_args(i, ret, blacklist, whitelist):
                                if i >= len(all_arg_semantic_types):
                                    yield ret
                                    return

                                arg_semantic_type = all_arg_semantic_types[i]
                                sub_context = context.clone_with_semantic_type(
                                    arg_semantic_type)
                                results = self.iter_expressions(
                                    max_depth=max_depth - 1,
                                    context=sub_context,
                                    function_weights=function_weights,
                                    use_unused_constants=use_unused_constants,
                                    unused_constants_whitelist=frozenset(
                                        whitelist),
                                    unused_constants_blacklist=frozenset(
                                        blacklist))

                                new_blacklist = blacklist
                                for expr in results:
                                    new_whitelist = whitelist | {
                                        c.name
                                        for c in expr.constants()
                                    }
                                    for sub_expr in product_sub_args(
                                            i + 1, ret + (expr, ),
                                            new_blacklist, new_whitelist):
                                        yield sub_expr
                                        new_blacklist = new_blacklist | {
                                            c.name
                                            for arg in sub_expr
                                            for c in arg.constants()
                                        }

                            for arg_combs in product_sub_args(
                                    0, tuple(), unused_constants_blacklist,
                                    unused_constants_whitelist):
                                candidate = B.make_application(
                                    fn.name, arg_combs)
                                valid = self.ontology._valid_application_expr(
                                    candidate)
                                # print("\t" * (6 - max_depth + 1), "valid %s? %s" % (candidate, valid))
                                if valid:
                                    yield candidate
            elif expr_type == l.LambdaExpression and max_depth > 1:
                if context.semantic_type is None or not isinstance(
                        context.semantic_type, l.ComplexType):
                    continue

                for num_args in range(1, len(context.semantic_type.flat)):
                    for bound_var_types in itertools.product(
                            self.ontology.observed_argument_types,
                            repeat=num_args):
                        # TODO typecheck with type request

                        bound_vars = list(context.bound_vars)
                        subexpr_bound_vars = []
                        for new_type in bound_var_types:
                            subexpr_bound_vars.append(
                                next_bound_var(bound_vars + subexpr_bound_vars,
                                               new_type))
                        all_bound_vars = tuple(bound_vars + subexpr_bound_vars)

                        if context.semantic_type is not None:
                            # TODO strong assumption -- assumes that lambda variables are used first
                            subexpr_semantic_type_flat = context.semantic_type.flat[
                                num_args:]
                            subexpr_semantic_type = self.types[
                                subexpr_semantic_type_flat]
                        else:
                            subexpr_semantic_type = None

                        # Prepare enumeration context
                        sub_context = context.clone()
                        sub_context.bound_vars = all_bound_vars
                        sub_context.semantic_type = subexpr_semantic_type

                        results = self.iter_expressions(
                            max_depth=max_depth - 1,
                            context=sub_context,
                            function_weights=function_weights,
                            use_unused_constants=use_unused_constants,
                            unused_constants_whitelist=
                            unused_constants_whitelist,
                            unused_constants_blacklist=
                            unused_constants_blacklist)

                        for expr in results:
                            candidate = expr
                            for var in subexpr_bound_vars:
                                candidate = l.LambdaExpression(var, candidate)
                            valid = self.ontology._valid_lambda_expr(
                                candidate, bound_vars)
                            # print("\t" * (6 - max_depth), "valid lambda %s? %s" % (candidate, valid))
                            if valid:
                                # Assign variable types before returning.
                                extra_types = {
                                    bound_var.name: bound_var.type
                                    for bound_var in subexpr_bound_vars
                                }

                                try:
                                    # TODO make sure variable names are unique before this happens
                                    self.ontology.typecheck(
                                        candidate, extra_types)
                                except l.InconsistentTypeHierarchyException:
                                    pass
                                else:
                                    yield candidate
            elif expr_type == l.IndividualVariableExpression:
                for bound_var in context.bound_vars:
                    if context.semantic_type and not bound_var.type.matches(
                            context.semantic_type):
                        continue

                    # print("\t" * (6-max_depth), "var %s" % bound_var)

                    yield l.IndividualVariableExpression(bound_var)
            elif expr_type == l.ConstantExpression:
                if use_unused_constants:
                    try:
                        for constant in self.ontology.constant_system.iter_new_constants(
                                semantic_type=context.semantic_type,
                                unused_constants_whitelist=
                                unused_constants_whitelist,
                                unused_constants_blacklist=
                                unused_constants_blacklist):

                            yield l.ConstantExpression(constant)
                    except ValueError:
                        pass
                else:
                    for constant in self.ontology.constants:
                        if context.semantic_type is not None and not constant.type.matches(
                                context.semantic_type):
                            continue

                        yield l.ConstantExpression(constant)
            elif expr_type == l.FunctionVariableExpression:
                # NB we don't support enumerating bound variables with function types
                # right now -- the following only considers yielding fixed functions
                # from the ontology.
                for function in self.ontology.functions:
                    # TODO(Jiayuan Mao @ 04/10): check the correctness of the following line.
                    # I currently skip nullary functions since they are handled
                    # as constant expressions above.
                    if function.arity == 0:
                        continue

                    # Be a little strict here to avoid excessive enumeration -- only
                    # consider emitting functions when the type request specifically
                    # demands a function, not e.g. AnyType
                    if context.semantic_type is None or context.semantic_type == self.types.ANY_TYPE \
                        or not function.type.matches(context.semantic_type):
                        continue

                    yield l.FunctionVariableExpression(
                        l.Variable(function.name, function.type))
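
Stripped of type checking and constant bookkeeping, the enumeration is a depth-bounded recursion: constants are leaves, and each function application recurses into its argument slots with one less depth budget. A toy sketch of that skeleton over a made-up signature (hypothetical names, not the pyccg API):

import itertools

FUNCTIONS = {"f": 2, "g": 1}  # name -> arity
CONSTANTS = ["a", "b"]

def iter_exprs(max_depth):
    if max_depth == 0:
        return
    # Leaves: constants are always available.
    for const in CONSTANTS:
        yield const
    # Internal nodes: function applications whose arguments are drawn from
    # one depth level lower.
    if max_depth > 1:
        for name, arity in FUNCTIONS.items():
            for args in itertools.product(iter_exprs(max_depth - 1),
                                          repeat=arity):
                yield "%s(%s)" % (name, ", ".join(args))

print(list(iter_exprs(2)))
# ['a', 'b', 'f(a, a)', 'f(a, b)', 'f(b, a)', 'f(b, b)', 'g(a)', 'g(b)']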
Code Example #7
File: lexicon.py  Project: cogsci2020verb/pyccg
def predict_zero_shot(lex, tokens, candidate_syntaxes, sentence, ontology,
                      model, likelihood_fns, scorer,
                      sentence_meta=None,
                      queue_limit=5,
                      iter_expressions_args=None):
  """
  Make zero-shot predictions of the posterior `p(syntax, meaning | sentence)`
  for each of `tokens`.

  Args:
    lex: Current lexicon.
    tokens: Tokens for which novel lexical entries should be proposed.
    candidate_syntaxes: For each token, a distribution over candidate
      syntactic categories.
    sentence: Sentence (sequence of tokens) being parsed.
    ontology: Ontology used to enumerate candidate meaning expressions.
    model: Model passed through to the likelihood functions.
    likelihood_fns: Collection of likelihood functions
      `p(meanings | syntaxes, sentence, model)` used to score candidate
      meaning--syntax settings for a subset of `tokens`.  Each function should
      accept arguments `(tokens, candidate_categories, candidate_meanings,
      candidate_semantic_parse, model)`, where `tokens` are assigned specific
      categories given in `candidate_categories` and specific meanings given in
      `candidate_meanings`, yielding a single semantic analysis of the sentence
      `candidate_semantic_parse`. The function should return a log-likelihood
      `p(candidate_meanings | candidate_syntaxes, sentence, model)`.

  Returns:
    queues: A dictionary mapping each query token to a ranked sequence of
      candidates of the form
      `(logprob, (tokens, candidate_categories, candidate_semantics))`,
      describing a nonzero-probability novel mapping of a subset of `tokens` to
      syntactic categories `candidate_categories` and meanings
      `candidate_semantics`. The log-probability value given is
      `p(meanings, syntaxes | sentence, model)`, under the relevant provided
      meaning likelihoods and the lexicon's distribution over syntactic forms.
    dummy_vars: TODO
  """

  get_arity = (lex.ontology and lex.ontology.get_expr_arity) \
      or get_semantic_arity
  iter_expressions_args = iter_expressions_args or {}

  # We will restrict semantic arities based on the observed arities available
  # for each category. Pre-calculate the necessary associations.
  category_sem_arities = lex.category_semantic_arities()

  def iter_expressions_for_arity(arity, max_depth=3, blacklist=None,
                                 **kwargs):
    semantic_type = ontology.types[("e",) * (arity + 1)]

    passed_kwargs = dict()
    # First include global iter_expressions kwargs
    passed_kwargs.update(iter_expressions_args)

    # Now set instance-specific arguments
    assert 'semantic_type' not in passed_kwargs
    passed_kwargs['semantic_type'] = semantic_type
    assert 'unused_constants_blacklist' not in passed_kwargs
    passed_kwargs['unused_constants_blacklist'] = blacklist
    passed_kwargs.setdefault('max_depth', max_depth)
    for key, value in kwargs.items():
      passed_kwargs.setdefault(key, value)

    return ontology.iter_expressions(**passed_kwargs)

  def iter_expressions_for_category(cat, blacklist=None):
    """
    Generate candidate semantic expressions for a lexical entry with the given
    syntactic category. (Forms type requests based on known associations
    between `cat` and semantic expressions.)
    """
    return itertools.chain.from_iterable(
        iter_expressions_for_arity(arity, blacklist=blacklist, syntactic_type=cat)
        for arity in category_sem_arities[cat])

  # for expr_comb in tqdm(itertools.product(*candidate_exprs),
  def product_candidate_exprs(syntax_comb):
    # NB(Jiayuan Mao @ 04/11): accelerate the iteration.
    # TODO(Jiayuan Mao @ 04/11): do we need this? maybe the cache of iter_expressions can automatically handle this.
    if (len(syntax_comb) == 1) or (not iter_expressions_args.get('use_unused_constants', False)):
      candidate_exprs = tuple(list(iter_expressions_for_category(cat)) for cat in syntax_comb)
      return list(itertools.product(*candidate_exprs))

    candidate_exprs = list()
    black_list = set()
    for i, cat in enumerate(syntax_comb):
      this_candidate_exprs = list(iter_expressions_for_category(cat, frozenset(black_list)))
      black_list |= {c.name for expr in this_candidate_exprs for c in expr.constants()}
      candidate_exprs.append(this_candidate_exprs)
    return list(itertools.product(*candidate_exprs))

  # Shared dummy variables which are included in candidate semantic forms,
  # to be replaced by candidate lexical expressions during evaluation.
  dummy_vars = {token: l.Variable("F%03i" % i) for i, token in enumerate(tokens)}

  category_parse_results = {}
  candidate_queue = None
  for depth in trange(1, len(tokens) + 1, desc="Depths"):
    candidate_queue = UniquePriorityQueue(maxsize=queue_limit)

    token_combs = list(itertools.combinations(tokens, depth))
    # for token_comb in tqdm(token_combs, desc="Token combinations"):
    for token_comb in token_combs:
      # TODO(Jiayuan Mao @ 04/10): if there are multiple words to be induced at the same time, there will be a bug for use_unused_constants.
      token_syntaxes = [list(candidate_syntaxes[token].support) for token in token_comb]
      for syntax_comb in tqdm(itertools.product(*token_syntaxes),
                              total=np.prod(list(map(len, token_syntaxes))),
                              desc="Syntax combinations"):
        syntax_weights = [candidate_syntaxes[token][cat] for token, cat in zip(token_comb, syntax_comb)]
        if any(weight == 0 for weight in syntax_weights):
          continue

        # Attempt to parse with this joint syntactic assignment, and collect
        # the resulting syntactic parses + sentence-level semantic forms, with
        # dummy variables standing in where the candidate expressions will go.
        results = attempt_candidate_parse(lex, token_comb,
                                          syntax_comb,
                                          sentence,
                                          scorer,
                                          dummy_vars,
                                          sentence_meta=sentence_meta)
        category_parse_results[syntax_comb] = results

        # Now enumerate semantic forms.
        # candidate_exprs = tuple(list(iter_expressions_for_category(cat))
        #                         for cat in syntax_comb)
        # n_expr_combs = np.prod(list(map(len, candidate_exprs)))
        all_expr_combs = product_candidate_exprs(syntax_comb)
        for expr_comb in tqdm(all_expr_combs, desc="Expressions"):
          # Compute likelihood of this joint syntax--semantics assignment.
          # TODO(Jiayuan Mao @ 04/08): += logp? or += p?
          # Probably we can remove the following line?
          likelihood = 0.0
          for result in results:
            # Swap in semantic values for each token.
            sentence_semantics = result.label()[0].semantics()
            for token, token_expr in zip(token_comb, expr_comb):
              dummy_var = dummy_vars[token]
              sentence_semantics = sentence_semantics.replace(dummy_var, token_expr)
            sentence_semantics = sentence_semantics.simplify()

            try:
              lex.ontology.typecheck(sentence_semantics)
            except l.TypeException as exc:
              continue

            # print('SUCCESS: ' + '; '.join([f'{t} => {str(s)} [{str(e)}]' for t, s, e in zip(token_comb, syntax_comb, expr_comb)]), sentence_semantics, sep='\n', end='\n' + '-'*120 + '\n')

            # Compute p(meaning | syntax, sentence, parse)
            logp = sum(likelihood_fn(token_comb, syntax_comb, expr_comb,
                                     sentence_semantics, model)
                       for likelihood_fn in likelihood_fns)
            likelihood += np.exp(logp)

            # Add category priors.
            log_prior = sum(T.log(weight) for weight in syntax_weights)
            joint_score = log_prior + logp
            if joint_score == -np.inf:
              # Zero probability. Skip.
              continue

            data = tuple_unordered([token_comb, syntax_comb, expr_comb])
            new_item = (joint_score, data)
            try:
              candidate_queue.put_nowait(new_item)
            except queue.Full:
              # See if this candidate is better than the worst item.
              worst = candidate_queue.get()
              if worst[0] < joint_score:
                replacement = new_item
              else:
                replacement = worst

              candidate_queue.put_nowait(replacement)

    if candidate_queue.qsize() > 0:
      # We have a result. Quit and don't search at higher depth.
      return candidate_queue, dummy_vars
  return candidate_queue, dummy_vars
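
`product_candidate_exprs` differs from a plain Cartesian product by threading a constant blacklist through the per-token enumerations, so that later tokens avoid constants already proposed for earlier ones. A toy sketch of that threading (made-up data, not the pyccg API):

import itertools

def exprs_for(cat, blacklist):
    # Stand-in for iter_expressions_for_category: toy candidates per
    # category, filtered against blacklisted constant names.
    pool = {"N": ["a", "b"], "V": ["f(a)", "f(b)", "g(c)"]}[cat]
    return [e for e in pool if not any(c in e for c in blacklist)]

def product_exprs(cats):
    candidates, blacklist = [], set()
    for cat in cats:
        exprs = exprs_for(cat, blacklist)
        # Constants proposed for this token are off-limits for later tokens.
        blacklist |= {c for e in exprs for c in "abc" if c in e}
        candidates.append(exprs)
    return list(itertools.product(*candidates))

print(product_exprs(["N", "V"]))  # [('a', 'g(c)'), ('b', 'g(c)')]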