def _augment(transform_payload: List[dict]) -> List[str]:
    """Run ``transform.js`` over a batch of programs, falling back to the raw sources.

    Args:
        transform_payload: One dict per program. Each dict must hold the program text
            under the ``"src"`` key (needed for the fallback). The whole list is
            JSON-serialized and passed to ``dispatch_to_node("transform.js", ...)``.

    Returns:
        The list decoded from the process's stdout. If the process reports failure, or
        its stdout cannot be decoded, the untransformed ``"src"`` strings are returned
        instead (in input order).

    Raises:
        AssertionError: If stdout decodes successfully to something that is not a list.
    """

    def untransformed() -> List[str]:
        # Built lazily so a payload without "src" keys only matters on the failure path.
        return [prog["src"] for prog in transform_payload]

    # Serialize into a separate name: the fallback above needs the original list of
    # dicts, not the JSON string (indexing characters with "src" raises TypeError).
    payload_json = json.dumps(transform_payload)
    is_successful, stdout, stderr = dispatch_to_node("transform.js", payload_json)

    if is_successful:
        try:
            transformed = json.loads(stdout)
        except json.JSONDecodeError:
            # Transformation failed in Node.js (got malformed stdout), so don't transform
            logger.error(f"JSONDecodeError in _augment transform input: {payload_json}")
            logger.error(f"JSONDecodeError in _augment transform stdout: {stdout}")
            logger.error(f"JSONDecodeError in _augment transform stderr: {stderr}")
            transformed = untransformed()
        except Exception as e:
            # Any other decoding problem (e.g. stdout of an unexpected type): don't transform
            logger.error(f"Exception (maybe JSONDecodeError) in _augment: {e}")
            logger.error(f"Exception in _augment transform input: {payload_json}")
            logger.error(f"Exception in _augment transform stdout: {stdout}")
            logger.error(f"Exception in _augment transform stderr: {stderr}")
            transformed = untransformed()
    else:
        logger.error("Non-zero exit code in _augment:")
        logger.error(f"Exception in _augment transform input: {payload_json}")
        logger.error(f"Exception in _augment transform stdout: {stdout}")
        logger.error(f"Exception in _augment transform stderr: {stderr}")
        transformed = untransformed()

    # Kept as an assert so callers see the same exception type as before.
    assert isinstance(transformed, list)
    return transformed
def filter_dataset(path: str, out_path: str, require_fields=None, exclude_transform_errors=False):
    """Filter a (possibly gzipped) JSON-lines dataset and write the surviving records.

    A record is kept only if all of the following hold:

    * every name in ``require_fields`` is present in the record with a truthy value;
    * when ``"identifier"`` is one of the required fields, the record's identifier
      matches ``_valid_identifier_regex``;
    * when ``exclude_transform_errors`` is set, running the ``identity_ast2ast``
      transform on the record's ``"function"`` source reports no error.

    Args:
        path: Input file. Opened with gzip when the name ends in ``.jsonl.gz``,
            otherwise opened as plain text.
        out_path: Output file; the same ``.jsonl.gz`` rule selects gzip.
        require_fields: Field names that must be present and truthy. ``None`` (the
            default) means no fields are required.
        exclude_transform_errors: If true, also drop records whose code fails the
            transform.
    """
    # Normalize before logging so the message still shows "[]" for the default.
    require_fields = [] if require_fields is None else require_fields
    logger.debug(f"Requiring fields {require_fields}")
    check_identifier = "identifier" in require_fields

    full_path = pathlib.Path(path).resolve()
    total_lines = 0
    examples = []
    logger.debug(f"Loading {full_path}")

    in_file = gzip.open(full_path, "rb") if path.endswith(".jsonl.gz") else full_path.open("r")
    with in_file:  # guarantees the handle is closed even if a record fails to parse
        reader = jsonlines.Reader(in_file)
        for json_dict in tqdm.tqdm(reader, desc=full_path.name):
            total_lines += 1

            # Check all required fields are present (and non-empty)
            if any(field not in json_dict or not json_dict[field] for field in require_fields):
                continue

            # We need the identifier (method name) as a label. Filter invalid identifiers
            if check_identifier and _valid_identifier_regex.match(json_dict["identifier"]) is None:
                continue

            # Try to parse/transform the code, and filter out if we can't
            if exclude_transform_errors:
                payload = json.dumps(
                    [
                        dict(
                            src=json_dict["function"],  # TODO: this key should be "code" for supervised set
                            augmentations=[{"fn": "identity_ast2ast"}],
                        )
                    ]
                )
                # NOTE(review): _augment in this file unpacks three values
                # (is_successful, stdout, stderr) from dispatch_to_node, while this
                # code originally unpacked two. Star-unpacking accepts either shape;
                # confirm the helper's actual return value.
                *status, _stdout, stderr = dispatch_to_node("transform.js", payload)
                if stderr or (status and not status[0]):
                    continue

            examples.append(json_dict)
            logger.debug(f"Filtered jsonl to {len(examples)}/{total_lines}")

    logger.debug(f"DONE: Filtered jsonl to {len(examples)}/{total_lines}")

    # TODO: Subsample

    # Write output
    full_out_path = pathlib.Path(out_path).resolve()
    out_file = gzip.open(full_out_path, "wb") if out_path.endswith(".jsonl.gz") else full_out_path.open("w")
    logger.debug(f"Writing output to {full_out_path}...")
    with out_file:  # closes (and for gzip, finalizes) the file even if writing fails
        jsonlines.Writer(out_file).write_all(examples)
    logger.debug("DONE writing")
def has_transform_error(json_dict):
    """Check whether a record's code fails the ``identity_ast2ast`` transform.

    The record is first passed through
    ``_fix_json_dict(json_dict, ["function"], "function", "identifier")``; the
    ``"function"`` value of the result is then sent to ``transform.js``.

    Args:
        json_dict: A dataset record; after ``_fix_json_dict`` it must contain the
            source code under the ``"function"`` key.

    Returns:
        ``(True, None)`` if the transform reported an error, otherwise
        ``(False, fixed_json_dict)`` where ``fixed_json_dict`` is the value returned
        by ``_fix_json_dict``.
    """
    # Try to parse/transform the code, and filter out if we can't
    json_dict = _fix_json_dict(json_dict, ["function"], "function", "identifier")

    # Set up transformation input
    payload = json.dumps(
        [
            dict(
                src=json_dict["function"],  # TODO: this key should be "code" for supervised set
                augmentations=[{"fn": "identity_ast2ast"}],
            )
        ]
    )

    # NOTE(review): _augment in this file unpacks three values
    # (is_successful, stdout, stderr) from dispatch_to_node, while this function
    # originally unpacked two. Star-unpacking accepts either shape; confirm the
    # helper's actual return value.
    *status, _stdout, stderr = dispatch_to_node("transform.js", payload)
    if stderr or (status and not status[0]):
        return True, None
    return False, json_dict