Example #1
0
def split_many_files(input_dir: RichPath, output_dir: RichPath,
                     train_ratio: float, valid_ratio: float, test_ratio: float,
                     test_only_projects: Set[str]) -> None:
    """Split every file of ``input_dir`` into per-fold output directories.

    For each of the folds ``train``, ``valid``, ``test`` and ``test-only`` a
    directory named ``<basename of input_dir>-<fold>`` is derived under
    ``output_dir`` and ``make_as_dir()`` is called on it. Every entry returned
    by ``input_dir.get_filtered_files_in_dir('*')`` is then handed to
    ``split_file`` through a ``multiprocessing`` pool.

    Args:
        input_dir: Directory whose files are to be split.
        output_dir: Parent directory for the four per-fold directories.
        train_ratio: Forwarded unchanged to ``split_file``.
        valid_ratio: Forwarded unchanged to ``split_file``.
        test_ratio: Forwarded unchanged to ``split_file``.
        test_only_projects: Forwarded unchanged to ``split_file``.
    """
    output_paths = {}  # type: Dict[str, RichPath]
    for split_name in ['train', 'valid', 'test', 'test-only']:
        split_dir = output_dir.join(input_dir.basename() + '-' + split_name)
        output_paths[split_name] = split_dir
        split_dir.make_as_dir()

    jobs = [(f, output_paths, train_ratio, valid_ratio, test_ratio,
             test_only_projects)
            for f in input_dir.get_filtered_files_in_dir('*')]

    # Use the pool as a context manager so the worker processes are always
    # released, even if a worker raises. starmap() blocks until every job has
    # finished, so tearing the pool down on exit cannot cut work short.
    with Pool() as pool:
        pool.starmap(split_file, jobs)
Example #2
0
def split_file(input_path: RichPath, output_paths: Dict[str, RichPath],
               train_ratio: float, valid_ratio: float, test_ratio: float,
               test_only_projects: Set[str]) -> None:
    """Distribute the datapoints of one input file over the fold outputs.

    Each datapoint read from ``input_path`` is assigned a fold by calling
    ``get_fold`` with its ``'Filename'`` entry, ``train_ratio``,
    ``valid_ratio`` and ``test_only_projects``. Datapoints whose fold is not
    one of ``train``, ``valid``, ``test`` or ``test-only`` are dropped.

    The ``train``/``valid``/``test`` outputs are written whenever the matching
    ratio is greater than zero (even if no datapoint landed in them); the
    ``test-only`` output is written only when it holds at least one datapoint.
    Each output is saved under ``output_paths[<fold>]`` with the same basename
    as the input file.

    If reading raises ``EOFError`` a failure message is printed and nothing
    is written.
    """
    graphs_by_fold = {
        'train': [],
        'valid': [],
        'test': [],
        'test-only': [],
    }  # type: Dict[str, list]

    try:
        for sample in input_path.read_by_file_suffix():
            fold = get_fold(sample['Filename'], train_ratio, valid_ratio,
                            test_only_projects)
            bucket = graphs_by_fold.get(fold)
            # Unknown fold names are silently skipped.
            if bucket is not None:
                bucket.append(sample)
    except EOFError:
        print('Failed for file %s.' % input_path)
        return

    file_name = input_path.basename()

    def write_fold(fold_name):
        # Save the datapoints collected for one fold next to its siblings.
        target = output_paths[fold_name].join(file_name)
        print('Saving %s...' % (target, ))
        target.save_as_compressed_file(graphs_by_fold[fold_name])

    if train_ratio > 0:
        write_fold('train')

    if valid_ratio > 0:
        write_fold('valid')

    if test_ratio > 0:
        write_fold('test')

    # Unlike the ratio-driven folds, test-only is written only when non-empty.
    if graphs_by_fold['test-only']:
        write_fold('test-only')