def build_representation(python_source, values_lists, branch_decisions_lists,
                         tokens_per_statement, base, target_output_length,
                         output_mod):
    """Builds a partial example_dict representation of the already run source.

    The per-statement lists are walked pairwise. Each pair contributes
    `tokens_per_statement` slots to the intermediate-output sequence and to
    the branch-decision sequence, plus parallel boolean masks that mark which
    slots hold real data.

    Args:
      python_source: Source text. Its newline-separated lines determine
        "num_statements" and "length".
      values_lists: Per-statement lists of recorded value dicts. Only the
        "v0" entry of the most recent dict is used.
      branch_decisions_lists: Per-statement lists of branch decisions. Only
        the most recent one is used. Must be empty exactly when the matching
        values list is empty (checked with asserts).
      tokens_per_statement: Number of slots emitted per statement.
      base: Base handed to `encoders.as_nary_list`.
      target_output_length: Number of digits each statement output is
        encoded into.
      output_mod: When not None, each statement output is reduced with
        `%= output_mod`. If that raises TypeError, the output becomes 1.

    Returns:
      A dict with keys "statements", "length", "num_statements",
      "intermediate_outputs", "intermediate_outputs_mask",
      "intermediate_output_lengths", "intermediate_outputs_count",
      "branch_decisions", "branch_decisions_count" and
      "branch_decisions_mask".
    """
    outputs, outputs_mask = [], []
    decisions, decisions_mask = [], []

    for statement_values, statement_decisions in zip(values_lists,
                                                     branch_decisions_lists):
        # The most recent value of each statement is what the code model sees
        # as its intermediate value (used for auxiliary losses). In a trace,
        # every statement has run exactly once, so there is only one value.
        if not statement_values:
            # Statement never executed: emit masked-out placeholders.
            assert not statement_decisions
            digits = [NO_OUTPUT] * target_output_length
            digits_mask = [False] * target_output_length
            decision = python_interpreter_trace.NO_BRANCH_DECISION
        else:
            assert statement_decisions
            latest = statement_values[-1]["v0"]
            if output_mod is not None:
                try:
                    latest %= output_mod
                except TypeError:
                    # Values that cannot be reduced are replaced by 1.
                    latest = 1
            digits = encoders.as_nary_list(latest, base, target_output_length)
            digits_mask = [True] * target_output_length
            decision = statement_decisions[-1]

        # Left-pad the encoded digits so each statement fills exactly
        # tokens_per_statement slots (when target_output_length fits).
        leading = tokens_per_statement - target_output_length
        outputs += [NO_OUTPUT] * leading + digits
        outputs_mask += [False] * leading + digits_mask

        # Only the final slot of a statement carries its branch decision.
        filler = tokens_per_statement - 1
        decisions += [0] * filler + [decision]
        decisions_mask += [False] * filler + [
            decision is not python_interpreter_trace.NO_BRANCH_DECISION
        ]

    line_count = len(python_source.split("\n"))
    output_count = len(outputs)
    return {
        "statements": python_source,
        "length": tokens_per_statement * line_count,
        "num_statements": line_count,
        "intermediate_outputs": outputs,
        "intermediate_outputs_mask": outputs_mask,
        "intermediate_output_lengths": [1] * output_count,
        "intermediate_outputs_count": output_count,
        "branch_decisions": decisions,
        "branch_decisions_count": len(decisions),
        "branch_decisions_mask": decisions_mask,
    }
示例#2
0
 def test_as_nary_list(self, number, base, length, target):
     """Asserts that `encoders.as_nary_list` encodes `number` as `target`.

     Args:
       number: Value to encode.
       base: Base passed to the encoder as the `base` keyword.
       length: Digit count passed to the encoder as the `length` keyword.
       target: Expected encoder output.
     """
     self.assertEqual(
         encoders.as_nary_list(number, base=base, length=length), target)
def _generate_example_from_python_source(executor, base, python_source,
                                         tokens_per_statement,
                                         target_output_length, mod,
                                         output_mod):
    """Generates an example dict from the given statements.

    Converts `python_source` to a control flow graph and evaluates it with a
    tracing callback. It then calls `build_representation` twice: once over
    the original statements (keys prefixed "code_") and once over the
    executed trace (keys prefixed "trace_").

    Args:
      executor: Forwarded to `python_interpreter.evaluate_cfg`.
      base: Base used for the n-ary encoding of outputs.
      python_source: Program text.
      tokens_per_statement: Number of slots per statement in the features.
      target_output_length: Number of digits in each encoded output.
      mod: Forwarded as `mod` to `python_interpreter.evaluate_cfg`.
      output_mod: When not None, the target output is reduced with
        `%= output_mod`. If that raises TypeError, it becomes 1.

    Returns:
      A dict holding the human-readable code and target output, the encoded
      target output, "lm_text", "error_type" ("NoError" when evaluation did
      not raise, otherwise the exception's class name), the CFG entries, and
      the prefixed code/trace features.
    """
    cfg = python_programs.to_cfg(python_source)
    source_lines = python_source.strip().split("\n")

    # TODO(dbieber): This should occur in exactly one location.
    # (also in environment.py)
    # If evaluation raises, `state` still refers to this initial dict (the
    # same object that was passed as initial_values).
    state = {"v0": 1}
    trace_fn = python_interpreter_trace.make_trace_fn(python_source, cfg)
    # TODO(dbieber): Evaluating may have already occurred in environment.
    try:
        state = python_interpreter.evaluate_cfg(executor,
                                                cfg,
                                                mod=mod,
                                                initial_values=state,
                                                trace_fn=trace_fn,
                                                timeout=200)
    except Exception as exc:  # pylint: disable=broad-except
        error_type = type(exc).__name__
    else:
        error_type = "NoError"

    answer = state["v0"]
    if output_mod is not None:
        try:
            answer %= output_mod
        except TypeError:
            answer = 1

    trace = trace_fn.trace
    code_features = build_representation(
        python_source, trace.cfg_node_index_values,
        trace.cfg_node_index_branch_decisions, tokens_per_statement, base,
        target_output_length, output_mod)

    use_full_lines_in_trace = False
    if use_full_lines_in_trace:
        trace_python_source = "\n".join(
            source_lines[line_index]
            for line_index in trace.trace_line_indexes)
    else:
        # TODO(dbieber): This also occurs in environment `state_as_example`.
        # Refactor.
        # Rebuild the trace's text by unparsing the AST node behind each
        # executed control flow node, in execution order.
        trace_nodes = [
            cfg.nodes[node_index]
            for node_index in trace.trace_cfg_node_indexes
        ]
        trace_python_source = "\n".join(
            astunparse.unparse(node.instruction.node,
                               version_info=(3, 5)).strip()
            for node in trace_nodes)

    trace_features = build_representation(
        trace_python_source, trace.trace_values,
        trace.trace_branch_decisions, tokens_per_statement, base,
        target_output_length, output_mod)

    target_output_list = encoders.as_nary_list(answer, base,
                                               target_output_length)
    lm_text = f"{python_source} SEP {answer}"

    example_dict = {
        # human_readable_features
        "human_readable_code": python_source,
        "human_readable_target_output": str(answer),

        # target_output
        "target_output": target_output_list,
        "target_output_length": target_output_length,
        "lm_text": lm_text,
        "error_type": error_type,

        # control flow graph
        "cfg": (cfg, python_source),
        "cfg_forward": (cfg, python_source),
    }
    for prefix, features in (("code_", code_features),
                             ("trace_", trace_features)):
        example_dict.update(
            {prefix + key: value for key, value in features.items()})
    return example_dict