def split_many_files(input_dir: RichPath, output_dir: RichPath, train_ratio: float,
                     valid_ratio: float, test_ratio: float,
                     test_only_projects: Set[str]) -> None:
    """Split every file found in ``input_dir`` into per-fold output directories.

    For each of the folds 'train', 'valid', 'test' and 'test-only', a directory
    named ``<basename of input_dir>-<fold>`` is created under ``output_dir``.
    Every file in ``input_dir`` is then passed to ``split_file`` through a
    worker pool, one job per file.

    Args:
        input_dir: Directory whose files are to be split.
        output_dir: Directory under which the four fold directories are created.
        train_ratio: Forwarded unchanged to ``split_file``.
        valid_ratio: Forwarded unchanged to ``split_file``.
        test_ratio: Forwarded unchanged to ``split_file``.
        test_only_projects: Forwarded unchanged to ``split_file``.
    """
    output_paths = {}  # type: Dict[str, RichPath]
    for split_name in ['train', 'valid', 'test', 'test-only']:
        graph_dir_name_for_split_type = input_dir.basename() + '-' + split_name
        graph_dir_for_split_type = output_dir.join(graph_dir_name_for_split_type)
        output_paths[split_name] = graph_dir_for_split_type
        graph_dir_for_split_type.make_as_dir()

    # Build the job list before spawning workers so a failure while listing
    # the input directory does not leave a pool behind.
    jobs = [(f, output_paths, train_ratio, valid_ratio, test_ratio, test_only_projects)
            for f in input_dir.get_filtered_files_in_dir('*')]

    # The context manager releases the worker processes once we are done;
    # starmap blocks until every job has finished, so no work is cut short.
    with Pool() as pool:
        pool.starmap(split_file, jobs)
def split_file(input_path: RichPath, output_paths: Dict[str, RichPath], train_ratio: float,
               valid_ratio: float, test_ratio: float, test_only_projects: Set[str]) -> None:
    """Distribute the datapoints of one input file over the fold output files.

    Each datapoint read from ``input_path`` is assigned a fold name by
    ``get_fold`` (based on the datapoint's 'Filename' entry) and collected in
    the matching bucket. Datapoints whose fold name is not one of 'train',
    'valid', 'test' or 'test-only' are dropped.

    The 'train', 'valid' and 'test' outputs are written whenever the
    corresponding ratio is positive, even if the bucket is empty. The
    'test-only' output is written only if at least one datapoint landed there.
    Every output file reuses the basename of ``input_path``.

    If reading raises ``EOFError``, a failure message is printed and nothing
    is written for this file.

    NOTE(review): ``test_ratio`` is not passed to ``get_fold``; it only gates
    whether the 'test' output is written.

    Args:
        input_path: File to read datapoints from.
        output_paths: Maps each fold name to the directory its output goes in.
        train_ratio: Passed to ``get_fold``; also gates writing 'train'.
        valid_ratio: Passed to ``get_fold``; also gates writing 'valid'.
        test_ratio: Gates writing 'test'.
        test_only_projects: Passed to ``get_fold``.
    """
    buckets = {'train': [], 'valid': [], 'test': [], 'test-only': []}

    try:
        for sample in input_path.read_by_file_suffix():
            provenance = sample['Filename']
            fold = get_fold(provenance, train_ratio, valid_ratio, test_only_projects)
            bucket = buckets.get(fold)
            # Unknown fold names are silently ignored.
            if bucket is not None:
                bucket.append(sample)
    except EOFError:
        print('Failed for file %s.' % input_path)
        return

    basename = input_path.basename()

    def _save(fold_name):
        # Write the bucket collected for fold_name next to its sibling outputs.
        destination = output_paths[fold_name].join(basename)
        print('Saving %s...' % (destination, ))
        destination.save_as_compressed_file(buckets[fold_name])

    # Ratio-gated folds are written even when their bucket is empty.
    for fold_name, ratio in (('train', train_ratio),
                             ('valid', valid_ratio),
                             ('test', test_ratio)):
        if ratio > 0:
            _save(fold_name)

    # The test-only fold is written only when it actually received data.
    if buckets['test-only']:
        _save('test-only')