def build_representation(python_source, values_lists, branch_decisions_lists,
                         tokens_per_statement, base, target_output_length,
                         output_mod):
  """Builds a partial example_dict representation of the already run source.

  Every statement is given `tokens_per_statement` slots in the flattened
  per-token feature lists. The statement's encoded output occupies the last
  `target_output_length` slots, preceded by NO_OUTPUT padding. Its branch
  decision occupies the final slot, preceded by zeros.

  Args:
    python_source: Newline-separated source text; each line is counted as
      one statement.
    values_lists: Per-statement lists of recorded value dicts. The "v0" entry
      of the last dict is used. An empty list means no value was recorded.
      Presumably aligned one-to-one with the lines of `python_source`; note
      that `zip` silently stops at the shorter of the two per-statement lists.
    branch_decisions_lists: Per-statement lists of recorded branch decisions,
      parallel to `values_lists` (empty exactly when the values list is).
    tokens_per_statement: Number of token slots allotted to each statement.
    base: Base passed to `encoders.as_nary_list`.
    target_output_length: Length passed to `encoders.as_nary_list`; also the
      number of output slots (masked True) for statements that have a value.
    output_mod: If not None, each statement output is reduced modulo this
      value; outputs that do not support `%` are replaced with 1.

  Returns:
    A dict with the source ("statements"), its token length and statement
    count, the flattened intermediate outputs with their mask, lengths and
    count, and the flattened branch decisions with their count and mask.

  Raises:
    ValueError: If `target_output_length` exceeds `tokens_per_statement`,
      which would otherwise make every statement overflow its slots.
  """
  # Hoisted: identical for every statement. A negative value would make
  # `[x] * padding` empty and silently misalign the flattened features with
  # the reported "length".
  padding = tokens_per_statement - target_output_length
  if padding < 0:
    raise ValueError(
        f"target_output_length ({target_output_length}) must not exceed "
        f"tokens_per_statement ({tokens_per_statement}).")
  no_branch_decision = python_interpreter_trace.NO_BRANCH_DECISION

  intermediate_outputs = []
  intermediate_outputs_mask = []
  branch_decisions = []
  branch_decisions_mask = []
  for values_list, branch_decisions_list in zip(values_lists,
                                                branch_decisions_lists):
    # We select the most recent value of each statement for the code model's
    # intermediate values, to be used by the model for auxiliary losses.
    # For the trace representation, there only is one value per statement
    # since each statement has been run exactly once.
    if values_list:
      assert branch_decisions_list
      statement_output = values_list[-1]["v0"]
      if output_mod is not None:
        try:
          statement_output %= output_mod
        except TypeError:
          # Non-numeric outputs cannot be reduced; use a fixed placeholder.
          statement_output = 1
      statement_output_list = encoders.as_nary_list(
          statement_output, base, target_output_length)
      statement_output_mask = [True] * target_output_length
      branch_decision = branch_decisions_list[-1]
    else:
      # No value was recorded for this statement: emit masked-out slots.
      assert not branch_decisions_list
      statement_output_list = [NO_OUTPUT] * target_output_length
      statement_output_mask = [False] * target_output_length
      branch_decision = no_branch_decision

    intermediate_outputs.extend([NO_OUTPUT] * padding + statement_output_list)
    intermediate_outputs_mask.extend([False] * padding + statement_output_mask)
    # Only the final slot of each statement carries its branch decision.
    branch_decisions.extend([0] * (tokens_per_statement - 1) +
                            [branch_decision])
    branch_decisions_mask.extend(
        [False] * (tokens_per_statement - 1) +
        [branch_decision is not no_branch_decision])

  num_statements = len(python_source.split("\n"))
  intermediate_outputs_count = len(intermediate_outputs)
  intermediate_output_lengths = [1] * intermediate_outputs_count
  return {
      "statements": python_source,
      "length": tokens_per_statement * num_statements,
      "num_statements": num_statements,
      "intermediate_outputs": intermediate_outputs,
      "intermediate_outputs_mask": intermediate_outputs_mask,
      "intermediate_output_lengths": intermediate_output_lengths,
      "intermediate_outputs_count": intermediate_outputs_count,
      "branch_decisions": branch_decisions,
      "branch_decisions_count": len(branch_decisions),
      "branch_decisions_mask": branch_decisions_mask,
  }
def test_as_nary_list(self, number, base, length, target):
  """as_nary_list(number, base, length) should produce exactly `target`."""
  self.assertEqual(
      encoders.as_nary_list(number, base=base, length=length), target)
def _build_trace_python_source(python_source, cfg, trace, use_full_lines):
  """Returns the statements of `trace`, in recorded order, joined by newlines.

  Args:
    python_source: The original program text.
    cfg: Control flow graph built from `python_source`.
    trace: The trace recorded by the trace function during evaluation.
    use_full_lines: If True, copy whole source lines by line index; otherwise
      unparse the AST node of each traced control-flow node.
  """
  if use_full_lines:
    source_lines = python_source.strip().split("\n")
    return "\n".join(
        source_lines[line_index] for line_index in trace.trace_line_indexes)

  # TODO(dbieber): This also occurs in environment `state_as_example`.
  # Refactor.
  trace_lines = []
  for cfg_node_index in trace.trace_cfg_node_indexes:
    ast_node = cfg.nodes[cfg_node_index].instruction.node
    trace_line = astunparse.unparse(ast_node, version_info=(3, 5))
    trace_lines.append(trace_line.strip())
  return "\n".join(trace_lines)


def _generate_example_from_python_source(executor, base, python_source,
                                         tokens_per_statement,
                                         target_output_length, mod,
                                         output_mod):
  """Generates an example dict from the given statements.

  Evaluates the program's control flow graph with a tracing hook, then builds
  two feature sets via `build_representation`:
  - one over the original source, stored under "code_"-prefixed keys;
  - one over the traced statements, stored under "trace_"-prefixed keys.

  Args:
    executor: Executor passed through to `python_interpreter.evaluate_cfg`.
    base: Base used to encode outputs with `encoders.as_nary_list`.
    python_source: Program text to convert into an example.
    tokens_per_statement: Token slots per statement in the representations.
    target_output_length: Number of digits in each encoded output.
    mod: Passed to `evaluate_cfg` as its `mod` argument.
    output_mod: If not None, outputs are reduced modulo this value; outputs
      that do not support `%` are replaced with 1.

  Returns:
    A dict with human-readable fields, the encoded target output, LM text,
    the evaluation error type ("NoError" on success), the CFG, and the
    "code_"/"trace_" prefixed features.
  """
  human_readable_code = python_source
  cfg = python_programs.to_cfg(python_source)

  # TODO(dbieber): This should occur in exactly one location.
  # (also in environment.py)
  values = {"v0": 1}
  trace_fn = python_interpreter_trace.make_trace_fn(python_source, cfg)
  # TODO(dbieber): Evaluating may have already occurred in environment.
  try:
    values = python_interpreter.evaluate_cfg(
        executor, cfg, mod=mod, initial_values=values, trace_fn=trace_fn,
        timeout=200)
    error_type = "NoError"
  except Exception as e:  # pylint: disable=broad-except
    # Errors are recorded in the example rather than raised. `values` is not
    # reassigned here, so "v0" is read from the initial dict (assuming
    # evaluate_cfg does not mutate it in place -- TODO confirm).
    error_type = type(e).__name__

  target_output = values["v0"]
  if output_mod is not None:
    try:
      target_output %= output_mod
    except TypeError:
      # Non-numeric outputs cannot be reduced; use a fixed placeholder.
      target_output = 1

  trace = trace_fn.trace
  code_features = build_representation(
      python_source, trace.cfg_node_index_values,
      trace.cfg_node_index_branch_decisions, tokens_per_statement, base,
      target_output_length, output_mod)

  use_full_lines_in_trace = False
  trace_python_source = _build_trace_python_source(
      python_source, cfg, trace, use_full_lines_in_trace)
  trace_features = build_representation(
      trace_python_source, trace.trace_values, trace.trace_branch_decisions,
      tokens_per_statement, base, target_output_length, output_mod)

  target_output_list = encoders.as_nary_list(target_output, base,
                                             target_output_length)
  lm_text = f"{human_readable_code} SEP {target_output}"
  example_dict = {
      # human_readable_features
      "human_readable_code": human_readable_code,
      # "original_human_readable_code": human_readable_code,
      "human_readable_target_output": str(target_output),
      # target_output
      "target_output": target_output_list,
      "target_output_length": target_output_length,
      "lm_text": lm_text,
      "error_type": error_type,
      # control flow graph
      "cfg": (cfg, python_source),
      "cfg_forward": (cfg, python_source),
  }
  example_dict.update(
      {"code_" + key: value for key, value in code_features.items()})
  example_dict.update(
      {"trace_" + key: value for key, value in trace_features.items()})
  return example_dict