def build_the_dataset(data_prefix, name, data_impl, num_samples, seq_length,
                      seed, skip_warmup, build_index_mappings=True):
    """Build a single GPT2Dataset covering every document in the indexed data.

    Args:
        data_prefix: path prefix of the on-disk indexed dataset files.
        name: label used for logging and index-mapping file names.
        data_impl: indexed-dataset implementation identifier passed through
            to make_indexed_dataset.
        num_samples: number of samples the resulting dataset should expose.
        seq_length: token sequence length of each sample.
        seed: RNG seed forwarded to GPT2Dataset.
        skip_warmup: when True, skip the index warmup pass on load.
        build_index_mappings: forwarded to GPT2Dataset; when False the
            doc/sample/shuffle index files are not (re)built.

    Returns:
        A GPT2Dataset spanning all documents in the indexed dataset.
    """
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    total_num_of_documents = indexed_dataset.sizes.shape[0]
    print_rank_0(' {}:'.format(name))
    print_rank_0(' no. of documents:{}'.format(total_num_of_documents))
    # Use the full document range [0, total_num_of_documents).
    # (Removed a dead `dataset = None` assignment that was unconditionally
    # overwritten below.)
    documents = np.arange(start=0, stop=total_num_of_documents,
                          step=1, dtype=np.int32)
    dataset = GPT2Dataset(name, data_prefix, documents, indexed_dataset,
                          num_samples, seq_length, seed,
                          build_index_mappings=build_index_mappings)
    return dataset
def build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    seq_length,
    seed,
    skip_warmup,
):
    """Construct the train, validation, and test GPT2Dataset splits.

    The indexed dataset's documents are partitioned according to
    `splits_string`; an empty split yields None for that position.

    Returns:
        Tuple of (train_dataset, valid_dataset, test_dataset), each a
        GPT2Dataset or None when its split contains no documents.
    """
    # Load the indexed dataset and compute document boundaries per split.
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    total_num_of_documents = indexed_dataset.sizes.shape[0]
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Report the document range assigned to each split.
    print_rank_0(" > dataset split:")
    for split_name, idx in (("train", 0), ("validation", 1), ("test", 2)):
        print_rank_0(" {}:".format(split_name))
        print_rank_0(" document indices in [{}, {}) total of {} "
                     "documents".format(splits[idx], splits[idx + 1],
                                        splits[idx + 1] - splits[idx]))

    def _make_split(idx, split_name):
        # An empty split produces no dataset.
        if splits[idx + 1] <= splits[idx]:
            return None
        doc_ids = np.arange(start=splits[idx], stop=splits[idx + 1],
                            step=1, dtype=np.int32)
        return GPT2Dataset(
            split_name,
            data_prefix,
            doc_ids,
            indexed_dataset,
            train_valid_test_num_samples[idx],
            seq_length,
            seed,
        )

    return (_make_split(0, "train"),
            _make_split(1, "valid"),
            _make_split(2, "test"))
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Load the indexed dataset and log load time and document count.

    NOTE(review): a second definition of get_indexed_dataset_ appears later
    in this file and shadows this one — confirm which version is intended.
    """
    print_rank_0(' > building dataset index ...')
    t0 = time.time()
    dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    elapsed = time.time() - t0
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(elapsed))
    print_rank_0(' number of documents: {}'.format(dataset.sizes.shape[0]))
    return dataset
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Load an indexed dataset and log its document/sentence stats.

    NOTE(review): this re-defines get_indexed_dataset_ from earlier in the
    file and shadows it — confirm which version is intended to win.
    """
    print_rank_0(' > building dataset index ...')
    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    # Sanity check: the number of size entries must equal the last document
    # offset in doc_idx.  NOTE(review): `assert` is stripped under `python -O`,
    # so this check silently disappears in optimized runs.
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))
    print_rank_0(' > indexed dataset stats:')
    # doc_idx holds one boundary per document plus a trailing end marker,
    # hence the `- 1` when counting documents.
    print_rank_0(
        ' number of documents: {}'.format(indexed_dataset.doc_idx.shape[0] - 1))
    print_rank_0(' number of sentences: {}'.format(
        indexed_dataset.sizes.shape[0]))
    return indexed_dataset