def test_program_to_lispress_with_quotes_inside_string():
    """A String value containing a double quote is escaped when rendered and
    survives a render -> parse -> convert round trip unchanged."""
    quoted_value, _ = mk_value_op(value='i got quotes"', schema="String", idx=0)
    expected_program = Program(expressions=[quoted_value])

    rendered = render_pretty(program_to_lispress(expected_program))
    # The embedded quote must appear backslash-escaped inside the literal.
    assert rendered == '(#(String "i got quotes\\""))'

    parsed = parse_lispress(rendered)
    round_tripped, _ = lispress_to_program(parsed, 0)
    assert round_tripped == expected_program
def unnest_line(
    s: Lispress,
    idx: Idx,
    var_id_bindings: Tuple[Tuple[str, int], ...],
) -> Tuple[List[Expression], Idx, Idx, Tuple[Tuple[str, int], ...]]:
    """
    Helper function for `_unsugared_lispress_to_program`.
    Converts a Lispress s-expression into a list of Expressions, keeping track of
    variable bindings and the last Expression index used.

    :param s: the Sexp to unnest
    :param idx: the highest used Expression idx so far
    :param var_id_bindings: (variable name, expression idx) pairs for variables bound
        by `let` so far; when a name appears twice, the later pair wins
    :return: A 4-tuple containing the resulting list of Expressions, the idx of this whole program,
        the most recent idx used (often the same as the idx of the program), and the updated
        variable bindings.
    :raises AssertionError: on a malformed `let`, Value, or struct form, or a reference to an
        unbound variable.
    """
    if not isinstance(s, list):
        # Bare atom. A JSON-encoded string or int literal becomes a Value op.
        # Anything else (a symbol, or JSON of another type such as float, bool or
        # null) is re-dispatched as the one-element list `[s]`, i.e. handled like `(s)`.
        try:
            value = loads(s)
            schema = {str: "String", int: "Int"}[type(value)]
        except (JSONDecodeError, KeyError):
            return unnest_line([s], idx=idx, var_id_bindings=var_id_bindings)
        # Kept outside the `try` so errors raised by mk_value_op itself are not
        # mistaken for "not a literal" and silently rerouted.
        expr, idx = mk_value_op(value=value, schema=schema, idx=idx)
        return [expr], idx, idx, var_id_bindings

    if len(s) == 0:
        # The empty list becomes a Unit value.
        expr, idx = mk_value_op(s, schema="Unit", idx=idx)
        return [expr], idx, idx, var_id_bindings

    # EXTERNAL_LABEL markers are dropped before dispatching on the head.
    s = [x for x in s if x != EXTERNAL_LABEL]
    hd, *tl = s

    if not isinstance(hd, str):
        # we don't know how to handle this case, so we just pack the whole thing into a generic value
        expr, idx = mk_value_op(value=s, schema="Object", idx=idx)
        return [expr], idx, idx, var_id_bindings

    if _is_idx_str(hd):
        # argId pointer: look up the expression idx bound to this variable by an earlier `let`.
        var_id_dict = dict(var_id_bindings)
        assert hd in var_id_dict, f"Unbound variable {hd!r}."
        return [], var_id_dict[hd], idx, var_id_bindings

    if is_express_idx_str(hd):
        # external reference
        return [], unwrap_idx_str(hd), idx, var_id_bindings

    if hd == LET:
        assert (
            len(tl) >= 2 and len(tl[0]) % 2 == 0
        ), "let binding must have var_name, var_defn pairs and a body"
        result_exprs = []
        variables, *body_forms = tl
        for var_name, body in chunked(variables, 2):
            assert isinstance(var_name, str)
            exprs, arg_idx, idx, var_id_bindings = unnest_line(body, idx, var_id_bindings)
            result_exprs.extend(exprs)
            # Make this variable visible to later definitions and to the body.
            var_id_bindings += ((var_name, arg_idx),)
        # The idx of the whole `let` is the idx of its last body form.
        for body in body_forms:
            exprs, arg_idx, idx, var_id_bindings = unnest_line(body, idx, var_id_bindings)
            result_exprs.extend(exprs)
        return result_exprs, arg_idx, idx, var_id_bindings

    if hd == SEQUENCE:
        # handle programs that have multiple statements sequenced together
        result_exprs = []
        arg_idx = idx  # in case `tl` is empty
        for statement in tl:
            exprs, arg_idx, idx, var_id_bindings = unnest_line(statement, idx, var_id_bindings)
            result_exprs.extend(exprs)
        return result_exprs, arg_idx, idx, var_id_bindings

    if hd == OpType.Value.value:
        # Exactly one list argument `($schema $value...)`; the unpacking below relies on it.
        assert (
            len(tl) == 1 and isinstance(tl[0], list) and len(tl[0]) >= 1
        ), f"Values must have format '#($schema $value)'. Found '{render_compact(s)}' instead."
        ((schema, *value_tokens),) = tl
        value = " ".join(value_tokens)
        try:
            value = loads(value)
        except JSONDecodeError:
            # Not JSON: keep the space-joined raw tokens as the value.
            pass
        expr, idx = mk_value_op(value=value, schema=schema, idx=idx)
        return [expr], idx, idx, var_id_bindings

    if is_struct_op_schema(hd):
        # Arguments alternate key, value; an unpaired key would break the unpacking below.
        assert (
            len(tl) % 2 == 0
        ), f"Struct '{hd}' must have key/value argument pairs. Found '{render_compact(s)}' instead."
        name = hd
        result = []
        kvs = []
        for key, val in chunked(tl, 2):
            val_exprs, val_idx, idx, var_id_bindings = unnest_line(val, idx, var_id_bindings)
            result.extend(val_exprs)
            kvs.append((_named_arg_to_key(key), val_idx))
        struct_op, idx = mk_struct_op(name, dict(kvs), idx)
        return result + [struct_op], idx, idx, var_id_bindings

    # CallOp: unnest each positional argument, then call `hd` on their idxs.
    name = hd
    result = []
    args = []
    for a in tl:
        arg_exprs, arg_idx, idx, var_id_bindings = unnest_line(a, idx, var_id_bindings)
        result.extend(arg_exprs)
        args.append(arg_idx)
    call_op, idx = mk_call_op(name, args=args, idx=idx)
    return result + [call_op], idx, idx, var_id_bindings
def generate_express_for_topic(
    topic: str,
    kvs: Dict[str, Optional[str]],
    text: str,
    execution_trace: ExecutionTrace,
    salience_model: SalienceModelBase,
    latest_pointer: Optional[int],
    pointer_count: int,
    no_revise: bool,
    is_abandon: bool,
) -> Tuple[List[Expression], int, List[Dict[str, str]]]:
    """Build the dataflow expressions for one topic and its slot key/value pairs.

    Unless ``is_abandon`` is set, each slot in ``kvs`` (visited in sorted key
    order) becomes one of two constraints:

    * a slot whose value is ``None`` (deleted) becomes an unset constraint;
    * any other slot becomes an equality constraint over either a salience
      (refer) call or a literal String value.

    The slot constraints are combined into a single constraint for ``topic``.
    That constraint is then wrapped in one of:

    * an Abandon call, when ``is_abandon`` is set;
    * a Find call, when there is no earlier pointer or revise is disabled;
    * a revise of the main constraint, otherwise.

    Args:
        topic: Topic of the express.
        kvs: Slot full names mapped to their values; ``None`` marks a deleted slot.
        text: Text in scope, searched for each (lower-cased) slot value.
        execution_trace: The ExecutionTrace at the previous turn, passed to the
            salience model.
        salience_model: Model queried when deciding whether to emit a refer call.
        latest_pointer: Latest pointer to the topic's node; only whether it is
            ``None`` matters here.
        pointer_count: Current pointer count, used to generate unique indices.
        no_revise: If True, never emit refer or revise calls.
        is_abandon: If True, ignore ``kvs`` and emit an Abandon call over a topic
            constraint with no slot arguments.

    Returns:
        A tuple of (generated expressions, updated pointer count, records of
        salience lookups whose result did not match the target slot value).
    """
    expressions: List[Expression] = []
    # slot name -> expression id of that slot's constraint (unset or equality)
    slot_constraint_ids: Dict[str, int] = {}
    failed_refer_calls: List[Dict[str, str]] = []

    slot_items = [] if is_abandon else sorted(kvs.items())
    for slot_fullname, raw_value in slot_items:
        _domain, slot_name = get_domain_and_slot_name(slot_fullname)

        if raw_value is None:
            # A None value means the slot was deleted.
            unset_expr, pointer_count = mk_unset_constraint(idx=pointer_count)
            expressions.append(unset_expr)
            slot_constraint_ids[slot_name] = pointer_count
            continue

        target_value = raw_value.lower()
        assert target_value != ""

        # Best effort: refer to the value via salience only when it is not
        # mentioned in the current text, revise calls are allowed, and the
        # salience model actually returns the target value.
        use_refer = False
        if not mentioned_in_text(value=target_value, text=text) and not no_revise:
            salient = salience_model.get_salient_value(
                target_type=slot_name,
                execution_trace=execution_trace,
                exclude_values=set(),
            )
            if salient == target_value:
                use_refer = True
            else:
                # keep the miss so the salience model can be improved later
                failed_refer_calls.append(
                    {
                        "topic": topic,
                        "slotName": slot_name,
                        "targetSalienceValue": target_value,
                        "returnedSalienceValue": salient,
                    }
                )

        if use_refer:
            refer_exprs, pointer_count = mk_salience(tpe=slot_name, idx=pointer_count)
            expressions.extend(refer_exprs)
        else:
            literal_expr, pointer_count = mk_value_op(
                value=target_value,
                schema="String",
                idx=pointer_count,
            )
            expressions.append(literal_expr)

        eq_expr, pointer_count = mk_equality_constraint(val=pointer_count, idx=pointer_count)
        expressions.append(eq_expr)
        slot_constraint_ids[slot_name] = pointer_count

    topic_expr, pointer_count = mk_constraint(
        tpe=topic, args=slot_constraint_ids.items(), idx=pointer_count
    )
    expressions.append(topic_expr)

    if is_abandon:
        wrapper_fn = DataflowFn.Abandon.value
    elif latest_pointer is None or no_revise:
        wrapper_fn = DataflowFn.Find.value
    else:
        revise_exprs, pointer_count = mk_revise_the_main_constraint(
            tpe=topic, new_idx=pointer_count
        )
        expressions.extend(revise_exprs)
        return expressions, pointer_count, failed_refer_calls

    wrapper_expr, pointer_count = mk_call_op(
        name=wrapper_fn, args=[pointer_count], idx=pointer_count
    )
    expressions.append(wrapper_expr)
    return expressions, pointer_count, failed_refer_calls