def test_program_to_lispress_with_quotes_inside_string():
    """A String value containing a double quote is escaped when rendered,
    and rendering then parsing yields the original Program back."""
    quoted_value_op, _ = mk_value_op(value='i got quotes"', schema="String", idx=0)
    expected_program = Program(expressions=[quoted_value_op])

    lispress_text = render_pretty(program_to_lispress(expected_program))
    # the embedded quote must come out backslash-escaped
    assert lispress_text == '(#(String "i got quotes\\""))'

    parsed_sexp = parse_lispress(lispress_text)
    reparsed_program, _ = lispress_to_program(parsed_sexp, 0)
    assert reparsed_program == expected_program
def unnest_line(
    s: Lispress,
    idx: Idx,
    var_id_bindings: Tuple[Tuple[str, int], ...],
) -> Tuple[List[Expression], Idx, Idx, Tuple[Tuple[str, int], ...]]:
    """
    Helper function for `_unsugared_lispress_to_program`.
    Recursively converts a Lispress s-expression into a flat list of
    Expressions, threading the last-used Expression index and the variable
    bindings through every recursive call.

    Dispatch, in order:
      * non-list token: JSON-decoded; a `str` becomes a "String" value op and
        an `int` an "Int" value op. Anything else (undecodable, or a decoded
        type other than str/int) is re-dispatched as the one-element list `[s]`.
      * empty list: a "Unit" value op.
      * list whose head (after dropping EXTERNAL_LABEL tokens) is not a str:
        the whole list is packed into an "Object" value op.
      * head satisfying `_is_idx_str`: a reference to a bound variable; emits
        no Expressions.
      * head satisfying `is_express_idx_str`: an external reference; emits no
        Expressions.
      * LET: `(let (name1 defn1 name2 defn2 ...) body...)`.
      * SEQUENCE: each statement is unnested in order.
      * `OpType.Value.value`: a `#(schema token...)` value op.
      * head satisfying `is_struct_op_schema`: a struct op built from
        alternating key/value arguments.
      * anything else: a call op whose args are the unnested tail elements.

    :param s: the Sexp to unnest
    :param idx: the highest used Expression idx so far
    :param var_id_bindings: (variable name, Expression idx) pairs currently in
        scope; they are turned into a dict for lookup, so a later pair shadows
        an earlier one with the same name
    :return: A 4-tuple containing the resulting list of Expressions,
    the idx of this whole program,
    the most recent idx used (often the same as the idx of the program), and
    the (possibly extended) variable bindings.
    """
    if not isinstance(s, list):
        try:
            # bare value
            value = loads(s)
            known_value_types = {
                str: "String",
                int: "Int",
            }
            # KeyError here means a JSON type other than str/int (e.g. bool,
            # float, null); it falls through to the except branch below
            schema = known_value_types[type(value)]
            expr, idx = mk_value_op(value=value, schema=schema, idx=idx)
            return [expr], idx, idx, var_id_bindings
        except (JSONDecodeError, KeyError):
            # not a str/int literal (e.g. a bare symbol): handle it as the
            # one-element form `[s]`, so it goes through the list dispatch
            return unnest_line([s], idx=idx, var_id_bindings=var_id_bindings)
    elif len(s) == 0:
        # `()` becomes a Unit value op whose value is the empty list itself
        expr, idx = mk_value_op(s, schema="Unit", idx=idx)
        return [expr], idx, idx, var_id_bindings
    else:
        # EXTERNAL_LABEL tokens are dropped before dispatching on the head
        s = [x for x in s if x != EXTERNAL_LABEL]
        hd, *tl = s
        if not isinstance(hd, str):
            # we don't know how to handle this case, so we just pack the whole thing into a generic value
            expr, idx = mk_value_op(value=s, schema="Object", idx=idx)
            return [expr], idx, idx, var_id_bindings
        elif _is_idx_str(hd):
            # argId pointer
            var_id_dict = dict(var_id_bindings)
            # look up step index for var
            # (an unbound name fails this assert; asserts are stripped under `python -O`)
            assert hd in var_id_dict
            expr_id = var_id_dict[hd]
            # no new Expressions: this form's idx is the bound Expression's idx
            return [], expr_id, idx, var_id_bindings
        elif is_express_idx_str(hd):
            # external reference
            # no new Expressions and `idx` is left unchanged
            return [], unwrap_idx_str(hd), idx, var_id_bindings
        elif hd == LET:
            # tl[0] is the binding list of alternating name/definition items;
            # at least one body form must follow it
            assert (
                len(tl) >= 2 and len(tl[0]) % 2 == 0
            ), "let binding must have var_name, var_defn pairs and a body"
            result_exprs = []
            variables, *body_forms = tl
            for var_name, body in chunked(variables, 2):
                assert isinstance(var_name, str)
                # definitions are unnested in order, so later definitions can
                # refer to earlier names
                exprs, arg_idx, idx, var_id_bindings = unnest_line(
                    body, idx, var_id_bindings)
                result_exprs.extend(exprs)
                # NOTE(review): bindings are never removed when the let ends, so
                # names bound here are returned to the caller and stay visible
                # to forms unnested after this let — confirm this is intended
                var_id_bindings += ((var_name, arg_idx), )
            for body in body_forms:
                exprs, arg_idx, idx, var_id_bindings = unnest_line(
                    body, idx, var_id_bindings)
                result_exprs.extend(exprs)
            # the let's value is the idx of its last body form
            return result_exprs, arg_idx, idx, var_id_bindings
        elif hd == SEQUENCE:
            # handle programs that have multiple statements sequenced together
            result_exprs = []
            arg_idx = idx  # in case `tl` is empty
            for statement in tl:
                exprs, arg_idx, idx, var_id_bindings = unnest_line(
                    statement, idx, var_id_bindings)
                result_exprs.extend(exprs)
            # the sequence's value is the idx of its last statement
            return result_exprs, arg_idx, idx, var_id_bindings
        elif hd == OpType.Value.value:
            assert (
                len(tl) >= 1 and len(tl[0]) >= 1
            ), f"Values must have format '#($schema $value)'. Found '{render_compact(s)}' instead."
            # exactly one `(schema token...)` element is expected; more than
            # one raises ValueError from this unpacking
            ((schema, *value_tokens), ) = tl
            # tokens are rejoined with single spaces and JSON-decoded when
            # possible; otherwise the raw joined string is kept as the value
            value = " ".join(value_tokens)
            try:
                value = loads(value)
            except JSONDecodeError:
                pass
            expr, idx = mk_value_op(value=value, schema=schema, idx=idx)
            return [expr], idx, idx, var_id_bindings
        elif is_struct_op_schema(hd):
            name = hd
            result = []
            kvs = []
            # tail is alternating key/value items; keys are normalized by
            # `_named_arg_to_key`, values are unnested recursively
            for key, val in chunked(tl, 2):
                val_exprs, val_idx, idx, var_id_bindings = unnest_line(
                    val, idx, var_id_bindings)
                result.extend(val_exprs)
                kvs.append((_named_arg_to_key(key), val_idx))
            struct_op, idx = mk_struct_op(name, dict(kvs), idx)
            # argument Expressions precede the struct op that uses them
            return result + [struct_op], idx, idx, var_id_bindings
        else:
            # CallOp
            name = hd
            result = []
            args = []
            for a in tl:
                arg_exprs, arg_idx, idx, var_id_bindings = unnest_line(
                    a, idx, var_id_bindings)
                result.extend(arg_exprs)
                args.append(arg_idx)
            call_op, idx = mk_call_op(name, args=args, idx=idx)
            # argument Expressions precede the call op that uses them
            return result + [call_op], idx, idx, var_id_bindings
# Example #3
# 0
def generate_express_for_topic(
    topic: str,
    kvs: Dict[str, Optional[str]],
    text: str,
    execution_trace: ExecutionTrace,
    salience_model: SalienceModelBase,
    latest_pointer: Optional[int],
    pointer_count: int,
    no_revise: bool,
    is_abandon: bool,
) -> Tuple[List[Expression], int, List[Dict[str, str]]]:
    """Build the dataflow expressions for one topic and its slot key-value pairs.

    Slots are visited in sorted order of their full slot names. Each slot becomes
    one of the following:
      * an unset constraint, when its value is None;
      * a refer (salience) call, when the value is not mentioned in ``text``,
        revising is allowed, and the salience model returns exactly that value;
      * otherwise, a String value op wrapped in an equality constraint.
    The slots are gathered into a constraint of type ``topic``. That constraint
    is then abandoned, found, or used to revise the main constraint.

    Args:
        topic: Topic of the express; used as the constraint type and as the
            revise target type.
        kvs: Mapping from full slot name to slot value (lower-cased before use);
            ``None`` marks a deleted slot. Not used when ``is_abandon`` is True.
        text: Text in scope, searched for each slot value.
        execution_trace: The ExecutionTrace at the previous turn, passed to the
            salience model for refer calls.
        salience_model: The salience model to be used for refer calls.
        latest_pointer: Latest pointer to the node of the current topic; when
            None, a Find call is emitted instead of a revise.
        pointer_count: Current pointer count, used to generate unique indices.
        no_revise: If True, neither refer calls nor revise calls are used.
        is_abandon: If True, slots are skipped and the resulting (empty)
            constraint is wrapped in an Abandon call.

    Returns:
        A tuple of (the generated Expressions, the updated pointer count,
        the failed refer calls recorded to help improve the salience model).
    """
    exprs: List[Expression] = []
    # slot name -> expression id holding that slot's constraint (including unset slots)
    slot_pointers: Dict[str, int] = {}
    failed_refers: List[Dict[str, str]] = []

    slot_items = [] if is_abandon else sorted(kvs.items())

    for slot_fullname, raw_value in slot_items:
        _domain, slot_name = get_domain_and_slot_name(slot_fullname)

        if raw_value is None:
            # a None value means the slot was deleted
            unset_expr, pointer_count = mk_unset_constraint(idx=pointer_count)
            exprs.append(unset_expr)
            slot_pointers[slot_name] = pointer_count
            continue

        normalized = raw_value.lower()
        assert normalized != ""

        # Best effort: refer to the salient value only when the value is absent
        # from the current turn, revising is allowed, and the model agrees.
        use_refer = False
        if not mentioned_in_text(value=normalized, text=text) and not no_revise:
            salient = salience_model.get_salient_value(
                target_type=slot_name,
                execution_trace=execution_trace,
                exclude_values=set(),
            )
            if salient == normalized:
                use_refer = True
            else:
                # keep the miss so the salience model can be improved later
                failed_refers.append({
                    "topic": topic,
                    "slotName": slot_name,
                    "targetSalienceValue": normalized,
                    "returnedSalienceValue": salient,
                })

        if use_refer:
            refer_exprs, pointer_count = mk_salience(tpe=slot_name,
                                                     idx=pointer_count)
            exprs.extend(refer_exprs)
        else:
            value_expr, pointer_count = mk_value_op(value=normalized,
                                                    schema="String",
                                                    idx=pointer_count)
            exprs.append(value_expr)
            # the equality constraint points at the value op just created
            eq_expr, pointer_count = mk_equality_constraint(val=pointer_count,
                                                            idx=pointer_count)
            exprs.append(eq_expr)

        slot_pointers[slot_name] = pointer_count

    constraint_expr, pointer_count = mk_constraint(tpe=topic,
                                                   args=slot_pointers.items(),
                                                   idx=pointer_count)
    exprs.append(constraint_expr)

    if is_abandon or latest_pointer is None or no_revise:
        # abandoning takes priority; otherwise nothing to revise -> Find
        fn_name = DataflowFn.Abandon.value if is_abandon else DataflowFn.Find.value
        call_expr, pointer_count = mk_call_op(name=fn_name,
                                              args=[pointer_count],
                                              idx=pointer_count)
        exprs.append(call_expr)
    else:
        revise_exprs, pointer_count = mk_revise_the_main_constraint(
            tpe=topic, new_idx=pointer_count)
        exprs.extend(revise_exprs)

    return exprs, pointer_count, failed_refers