def remap_schema(self, dataset: Dataset):
    """Remap the dataset's columns to the SliceBuilder's input/output schemas."""
    # Ground the schema to the dataset
    input_grounding, reversed_input_grounding = self.input_schema.ground(
        dataset.features
    )
    output_grounding, reversed_output_grounding = self.output_schema.ground(
        dataset.features
    )

    # Construct a map_fn that remaps the dataset schema
    def map_fn(example):
        return tz.merge(
            {k: example[input_grounding[k]] for k in self.input_schema.features},
            {k: example[output_grounding[k]] for k in self.output_schema.features},
        )

    return dataset.map(
        map_fn,
        remove_columns=list(reversed_input_grounding.keys())
        + list(reversed_output_grounding.keys()),
    )
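
# Usage sketch for `remap_schema` (hedged: `builder` and `dataset` below are
# hypothetical placeholders, not names defined in this module):
#
#   builder = SomeSliceBuilder(...)           # subclass with input_schema/output_schema
#   remapped = builder.remap_schema(dataset)  # columns renamed to the schemas' feature
#                                             # names; the grounded originals are removed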
def process_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    mask: List[int] = None,
    store_compressed: bool = True,
    store: bool = True,
    num_proc: int = None,
    *args,
    **kwargs,
) -> Tuple[Dataset, List[Slice], np.ndarray]:
    """Apply a SliceBuilder to a dataset.

    Args:
        dataset: Dataset
        columns: list of columns
        batch_size: integer batch size
        mask: boolean or integer mask array, mask[i] = True means that the ith
            slice will be masked out
        store_compressed: whether to store in a compressed format
        store: whether to store the results along with the example in Dataset
        num_proc: num processes for multiprocessing
        *args: optional additional arguments
        **kwargs: optional additional keyword arguments

    Returns:
        tuple of (Dataset, list of Slices, matrix of (example, slice) membership)
    """
    # Prepare the dataset
    dataset = self.prepare_dataset(
        dataset=dataset,
        columns=columns,
        batch_size=batch_size,
        mask=mask,
        store_compressed=store_compressed,
        store=store,
        *args,
        **kwargs,
    )

    # Compute a hash
    val = persistent_hash(str(dataset.identifier)) ^ dataset.hash_interactions()
    for i, identifier in enumerate(self.identifiers):
        if not mask[i]:
            val ^= persistent_hash(str(identifier) + str(strings_as_json(columns)))

    try:
        # Map the SliceBuilder over the dataset
        all_sliced_batches = []
        all_slice_memberships = []

        def _map_fn(batch):
            """Map function for processing batches.

            Note that using this map_fn in a stateful way is dangerous, since
            every invocation of this function appends to the all_sliced_batches
            list. The .map() function will invoke this once for testing before
            performing the map, so we discard the first entry inserted into
            all_sliced_batches.
            """
            batch, sliced_batches, slice_membership = self.process_batch(
                batch=batch,
                columns=columns,
                mask=mask,
                store_compressed=store_compressed,
                store=store,
                *args,
                **kwargs,
            )
            all_sliced_batches.append(sliced_batches)
            all_slice_memberships.append(slice_membership)
            return batch

        dataset = dataset.map(
            _map_fn,
            batched=True,
            batch_size=batch_size,
            # FIXME(karan): enable this by adding logic for generating
            #  all_sliced_batches and all_slice_memberships
            #  when loading from cache file
            load_from_cache_file=False,
            # The cache file name is an XOR of the interaction history and the
            # current operation
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(val)) + ".arrow")
            ),
        )

        # Remove the first entry (see _map_fn)
        all_sliced_batches = all_sliced_batches[1:]
        all_slice_memberships = all_slice_memberships[1:]

    except:  # noqa
        # Batch the dataset, and process each batch
        all_batches, all_sliced_batches, all_slice_memberships = zip(
            *[
                self.process_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                )
                for batch in dataset.batch(batch_size)
            ]
        )

        # Update the dataset efficiently by reusing all_batches
        dataset = dataset.map(
            lambda examples, indices: all_batches[indices[0] // batch_size],
            batched=True,
            batch_size=batch_size,
            with_indices=True,
            load_from_cache_file=False,
            # The cache file name is an XOR of the interaction history and the
            # current operation
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(val)) + ".arrow")
            ),
        )

    # Create a single slice label matrix
    slice_membership = np.concatenate(all_slice_memberships, axis=0)

    slice_cache_hashes = []
    for identifier in self.identifiers:
        slice_cache_hashes.append(val ^ persistent_hash(str(identifier)))

    if not num_proc or num_proc == 1:
        # Construct slices
        slices = []
        for i, slice_batches in enumerate(zip(*all_sliced_batches)):
            slices.append(
                create_slice(
                    (
                        dataset,
                        slice_membership,
                        slice_batches,
                        i,
                        batch_size,
                        slice_cache_hashes[i],
                    )
                )
            )
    else:
        # Parallelized slice construction
        with Pool(num_proc) as pool:
            slices = pool.map(
                create_slice,
                [
                    (
                        dataset,
                        slice_membership,
                        slice_batches,
                        i,
                        batch_size,
                        slice_cache_hashes[i],
                    )
                    for i, slice_batches in enumerate(zip(*all_sliced_batches))
                ],
            )

    # TODO(karan): make this more systematic
    # TODO(karan): fix bug when slicing a Slice
    for i, sl in enumerate(slices):
        # # Set the Slice features
        # sl.info.features = dataset.features

        # Set the Slice category using the SliceBuilder's category
        sl.category = self.category

        # Create the lineage
        sl.lineage = [
            (str(Dataset.__name__), dataset.identifier),
            (
                str(self.category.capitalize()),
                self.identifiers[i],
                strings_as_json(columns),
            ),
        ]
        if isinstance(dataset, Slice):
            # Prepend the Slice's lineage instead, if the dataset was a slice
            sl.lineage = dataset.lineage + [
                (
                    str(self.category.capitalize()),
                    self.identifiers[i],
                    strings_as_json(columns),
                )
            ]

    return dataset, slices, slice_membership
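
# Usage sketch for `process_dataset` (hedged: the builder, dataset, and column names
# below are hypothetical; only the signature and return value come from this module):
#
#   dataset, slices, slice_membership = builder.process_dataset(
#       dataset=dataset,
#       columns=["text"],
#       batch_size=32,
#       num_proc=4,          # optional: construct slices in parallel via a Pool
#   )
#   # slice_membership is an (example, slice) membership matrix indicating which
#   # examples fall into which slice; `slices` is the corresponding list of Slices.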
def prepare_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    mask: List[int] = None,
    store_compressed: bool = True,
    store: bool = True,
    *args,
    **kwargs,
) -> Dataset:
    """Prepare a dataset by mapping prepare_batch over its batches."""
    # Compute the hash for this operation
    # FIXME(karan): this is repeated inside process_dataset
    val = persistent_hash(str(dataset.identifier)) ^ dataset.hash_interactions()
    for i, identifier in enumerate(self.identifiers):
        if not mask[i]:
            val ^= persistent_hash(str(identifier) + str(strings_as_json(columns)))

    try:
        return dataset.map(
            partial(
                self.prepare_batch,
                columns=columns,
                mask=mask,
                store_compressed=store_compressed,
                store=store,
                *args,
                **kwargs,
            ),
            batched=True,
            batch_size=batch_size,
            load_from_cache_file=False,
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(val)) + "-prep.arrow")
            ),
        )
    except:  # TypeError or PicklingError or AttributeError  # noqa
        # Batch the dataset, and process each batch
        all_batches = [
            self.prepare_batch(
                batch=batch,
                columns=columns,
                mask=mask,
                store_compressed=store_compressed,
                store=store,
                *args,
                **kwargs,
            )
            for batch in dataset.batch(batch_size)
        ]

        # Update the dataset efficiently by reusing all_batches
        return dataset.map(
            lambda examples, indices: all_batches[indices[0] // batch_size],
            batched=True,
            batch_size=batch_size,
            with_indices=True,
            load_from_cache_file=False,
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(val)) + "-prep.arrow")
            ),
        )
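
# Note on the fallback paths above (a description of the existing indexing trick,
# not new behavior): `dataset.map(..., batched=True, with_indices=True)` passes each
# batch along with the absolute indices of its examples. Because batches are drawn
# in order with a fixed `batch_size`, the batch starting at `indices[0]` is batch
# number `indices[0] // batch_size`, so the precomputed entry in `all_batches` can
# be returned directly instead of recomputing the batch inside `.map()`.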