def prepare_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    *args,
    **kwargs,
) -> None:
    """Apply a preparation function to the dataset.

    Use this to update attributes of `self`.

    Args:
        dataset: dataset
        columns: list of columns
        batch_size: batch size for preparation
        *args: optional additional arguments
        **kwargs: optional additional keyword arguments
    """
    # Set the data format
    with dataset.format(
        columns + self._filter_prerequisite_columns(columns, dataset.column_names)
    ):
        # Batch the dataset, and prepare each batch
        for batch in dataset.batch(batch_size):
            try:
                # Check if the `prepare_batch` function has been implemented
                self.prepare_batch(
                    batch=batch,
                    columns=columns,
                    *args,
                    **kwargs,
                )
            except NotImplementedError:
                break
def prepare_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    *args,
    **kwargs,
) -> None:
    """Preparation that is applied before the CachedOperation.

    Many CachedOperations require a full pass over the dataset to
    precompute some variables before the core operation can actually be
    applied, e.g. to create a Bag-of-Words representation, constructing
    a dataset vocabulary to keep only tokens that are frequently seen
    across the dataset.

    Args:
        dataset: Dataset
        columns: list of columns
        batch_size: batch size for batching the dataset
        *args: optional additional arguments
        **kwargs: optional additional keyword arguments
    """
    # Set the data format
    dataset.set_format(columns)

    # Batch the dataset, and prepare each batch
    for batch in dataset.batch(batch_size):
        try:
            # Check if the `prepare_batch` function has been implemented
            self.prepare_batch(
                batch=batch,
                columns=columns,
                *args,
                **kwargs,
            )
        except NotImplementedError:
            break

    # Reset the data format
    dataset.reset_format()
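# A minimal, self-contained sketch of the early-exit pattern used by both
# `prepare_dataset` variants above: the loop calls `prepare_batch` on the
# first batch, and if a subclass has not overridden it, the resulting
# NotImplementedError aborts the (otherwise wasted) pass over the dataset.
# The `OpWithoutPrepare` and `VocabBuilder` classes and the toy batches
# below are hypothetical, purely for illustration.
class OpWithoutPrepare:
    def prepare_batch(self, batch, columns):
        raise NotImplementedError

class VocabBuilder(OpWithoutPrepare):
    def __init__(self):
        self.vocab = set()

    def prepare_batch(self, batch, columns):
        # Preparation accumulates state on `self`, one batch at a time
        for col in columns:
            self.vocab.update(batch[col])

for op in (OpWithoutPrepare(), VocabBuilder()):
    for batch in [{"text": ["a", "b"]}, {"text": ["b", "c"]}]:
        try:
            op.prepare_batch(batch=batch, columns=["text"])
        except NotImplementedError:
            break  # no preparation implemented: skip the remaining batches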
def process_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    mask: List[int] = None,
    store_compressed: bool = True,
    store: bool = True,
    num_proc: int = None,
    *args,
    **kwargs,
) -> Tuple[Dataset, List[Slice], np.ndarray]:
    """Apply a SliceBuilder to a dataset.

    Args:
        dataset: Dataset
        columns: list of columns
        batch_size: integer batch size
        mask: boolean or integer mask array, mask[i] = True means that the
            ith slice will be masked out
        store_compressed: whether to store in a compressed format
        store: whether to store the results along with the examples in Dataset
        num_proc: num processes for multiprocessing
        *args: optional additional arguments
        **kwargs: optional additional keyword arguments

    Returns:
        tuple of (Dataset, list of Slices,
        matrix of (example, slice) membership)
    """
    # Prepare the dataset
    dataset = self.prepare_dataset(
        dataset=dataset,
        columns=columns,
        batch_size=batch_size,
        mask=mask,
        store_compressed=store_compressed,
        store=store,
        *args,
        **kwargs,
    )

    # Compute a hash
    val = persistent_hash(str(dataset.identifier)) ^ dataset.hash_interactions()
    for i, identifier in enumerate(self.identifiers):
        if not mask[i]:
            val ^= persistent_hash(str(identifier) + str(strings_as_json(columns)))

    try:
        # Map the SliceBuilder over the dataset
        all_sliced_batches = []
        all_slice_memberships = []

        def _map_fn(batch):
            """Map function for processing batches.

            Note that using this map_fn in a stateful way is dangerous,
            since every invocation of this function appends to the
            all_sliced_batches list. The .map() function will invoke this
            once for testing before performing the map, so we discard the
            first entry inserted into all_sliced_batches.
            """
            batch, sliced_batches, slice_membership = self.process_batch(
                batch=batch,
                columns=columns,
                mask=mask,
                store_compressed=store_compressed,
                store=store,
                *args,
                **kwargs,
            )
            all_sliced_batches.append(sliced_batches)
            all_slice_memberships.append(slice_membership)
            return batch

        dataset = dataset.map(
            _map_fn,
            batched=True,
            batch_size=batch_size,
            # FIXME(karan): enable this by adding logic for generating
            #  all_sliced_batches and all_slice_memberships
            #  when loading from cache file
            load_from_cache_file=False,
            # The cache file name is a XOR of the interaction history and the
            # current operation
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(val)) + ".arrow")
            ),
        )

        # Remove the first entry (see _map_fn)
        all_sliced_batches = all_sliced_batches[1:]
        all_slice_memberships = all_slice_memberships[1:]

    except:  # noqa
        # Batch the dataset, and process each batch
        all_batches, all_sliced_batches, all_slice_memberships = zip(
            *[
                self.process_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                )
                for batch in dataset.batch(batch_size)
            ]
        )

        # Update the dataset efficiently by reusing all_batches
        dataset = dataset.map(
            lambda examples, indices: all_batches[indices[0] // batch_size],
            batched=True,
            batch_size=batch_size,
            with_indices=True,
            load_from_cache_file=False,
            # The cache file name is a XOR of the interaction history and the
            # current operation
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(val)) + ".arrow")
            ),
        )

    # Create a single slice label matrix
    slice_membership = np.concatenate(all_slice_memberships, axis=0)

    slice_cache_hashes = []
    for identifier in self.identifiers:
        slice_cache_hashes.append(val ^ persistent_hash(str(identifier)))

    if not num_proc or num_proc == 1:
        # Construct slices sequentially
        slices = []
        for i, slice_batches in enumerate(zip(*all_sliced_batches)):
            slices.append(
                create_slice(
                    (
                        dataset,
                        slice_membership,
                        slice_batches,
                        i,
                        batch_size,
                        slice_cache_hashes[i],
                    )
                )
            )
    else:
        # Parallelized slice construction
        with Pool(num_proc) as pool:
            slices = pool.map(
                create_slice,
                [
                    (
                        dataset,
                        slice_membership,
                        slice_batches,
                        i,
                        batch_size,
                        slice_cache_hashes[i],
                    )
                    for i, slice_batches in enumerate(zip(*all_sliced_batches))
                ],
            )

    # TODO(karan): make this more systematic
    # TODO(karan): fix bug when slicing a Slice
    for i, sl in enumerate(slices):
        # # Set the Slice features
        # sl.info.features = dataset.features

        # Set the Slice category using the SliceBuilder's category
        sl.category = self.category

        # Create the lineage
        sl.lineage = [
            (str(Dataset.__name__), dataset.identifier),
            (
                str(self.category.capitalize()),
                self.identifiers[i],
                strings_as_json(columns),
            ),
        ]
        if isinstance(dataset, Slice):
            # Prepend the Slice's lineage instead, if the dataset was a slice
            sl.lineage = dataset.lineage + [
                (
                    str(self.category.capitalize()),
                    self.identifiers[i],
                    strings_as_json(columns),
                )
            ]

    return dataset, slices, slice_membership
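# Sketch of the cache-key scheme used in `process_dataset` above, with a
# stand-in for `persistent_hash` (the real helper lives elsewhere in the
# codebase and may differ). XOR-combining the dataset hash with one hash
# per unmasked identifier yields a stable key that changes whenever the
# dataset history, the operation, or the target columns change.
import hashlib

def _persistent_hash(s: str) -> int:
    # Stable across interpreter runs, unlike the builtin hash()
    return int(hashlib.sha256(s.encode("utf-8")).hexdigest(), 16)

val = _persistent_hash("dataset-identifier-v0")
for identifier, masked in [("length-slice", False), ("ner-slice", True)]:
    if not masked:
        val ^= _persistent_hash(identifier + '["text"]')
cache_file_name = "cache-" + str(abs(val)) + ".arrow"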
def prepare_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    mask: List[int] = None,
    store_compressed: bool = True,
    store: bool = True,
    *args,
    **kwargs,
) -> Dataset:
    """Apply a preparation function to the dataset, returning the updated
    Dataset."""
    # Compute the hash for this operation
    # FIXME(karan): this is repeated inside process_dataset
    val = persistent_hash(str(dataset.identifier)) ^ dataset.hash_interactions()
    for i, identifier in enumerate(self.identifiers):
        if not mask[i]:
            val ^= persistent_hash(str(identifier) + str(strings_as_json(columns)))

    try:
        return dataset.map(
            partial(
                self.prepare_batch,
                columns=columns,
                mask=mask,
                store_compressed=store_compressed,
                store=store,
                *args,
                **kwargs,
            ),
            batched=True,
            batch_size=batch_size,
            load_from_cache_file=False,
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(val)) + "-prep.arrow")
            ),
        )
    except:  # TypeError or PicklingError or AttributeError  # noqa
        # Batch the dataset, and process each batch
        all_batches = [
            self.prepare_batch(
                batch=batch,
                columns=columns,
                mask=mask,
                store_compressed=store_compressed,
                store=store,
                *args,
                **kwargs,
            )
            for batch in dataset.batch(batch_size)
        ]

        # Update the dataset efficiently by reusing all_batches
        return dataset.map(
            lambda examples, indices: all_batches[indices[0] // batch_size],
            batched=True,
            batch_size=batch_size,
            with_indices=True,
            load_from_cache_file=False,
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(val)) + "-prep.arrow")
            ),
        )
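# The fallback branch above computes every batch eagerly, then uses
# .map(..., with_indices=True) only to write the results back into the
# dataset: since batches arrive in order with a fixed batch_size,
# `indices[0] // batch_size` recovers the index of the precomputed batch.
# A standalone check of that indexing arithmetic:
batch_size = 2
data = list(range(7))
all_batches = [data[i : i + batch_size] for i in range(0, len(data), batch_size)]
for start in range(0, len(data), batch_size):
    indices = list(range(start, min(start + batch_size, len(data))))
    assert all_batches[indices[0] // batch_size] == data[start : start + batch_size]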
def process_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    # mask: List[int] = None,
    num_proc: int = None,
    *args,
    **kwargs,
) -> Tuple[List[Slice], np.ndarray]:
    """Apply a SliceBuilder to a dataset.

    Returns:
        tuple of (list of Slices, matrix of (example, slice) membership)
    """
    # Create slices using the dataset
    slices = [Slice(dataset) for _ in range(len(self.identifiers))]
    all_slice_memberships = []

    # Batch the dataset, and process each batch
    for batch in dataset.batch(batch_size):
        # Process the batch
        _, slice_memberships = self.process_batch(
            batch=batch,
            columns=columns,
            *args,
            **kwargs,
        )

        # Keep track of the slice memberships
        all_slice_memberships.append(slice_memberships)

    # Create a single slice label matrix
    slice_membership = np.concatenate(all_slice_memberships, axis=0)

    for i, sl in enumerate(slices):
        # Set the visible rows for each slice
        sl.set_visible_rows(np.where(slice_membership[:, i])[0])

        # Set the Slice category using the SliceBuilder's category
        sl.category = self.category

        # Append to the lineage
        sl.add_to_lineage(
            category=str(self.category.capitalize()),
            identifier=self.identifiers[i],
            columns=strings_as_json(columns),
        )

    return slices, slice_membership
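# How the membership matrix drives slice visibility above: column i of the
# (num_examples x num_slices) matrix marks which rows belong to slice i,
# and np.where turns it into the row indices passed to set_visible_rows.
# Toy example with 3 examples and 2 slices:
import numpy as np

slice_membership = np.array([
    [1, 0],
    [1, 1],
    [0, 1],
])
for i in range(slice_membership.shape[1]):
    visible_rows = np.where(slice_membership[:, i])[0]
    print(f"slice {i}: rows {visible_rows.tolist()}")
# slice 0: rows [0, 1]
# slice 1: rows [1, 2]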
def process_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    # mask: List[int] = None,
    num_proc: int = None,
    *args,
    **kwargs,
) -> Tuple[List[Slice], np.ndarray]:
    """Apply a SliceBuilder to a dataset.

    Args:
        dataset: Dataset
        columns: list of columns
        batch_size: integer batch size
        num_proc: num processes for multiprocessing
        *args: optional additional arguments
        **kwargs: optional additional keyword arguments

    Returns:
        tuple of (list of Slices, matrix of (example, slice) membership)
    """
    # Create slices
    slices = [Slice() for _ in range(len(self.identifiers))]
    all_slice_memberships = []

    # Batch the dataset, and process each batch
    for batch in dataset.batch(batch_size):
        # Process the batch
        sliced_batches, slice_memberships = self.process_batch(
            batch=batch,
            columns=columns,
            *args,
            **kwargs,
        )

        # Incrementally build the slices
        for sl, sl_batch in zip(slices, sliced_batches):
            sl._dataset.append(sl_batch)

        # Keep track of the slice memberships
        all_slice_memberships.append(slice_memberships)

    # Create a single slice label matrix
    slice_membership = np.concatenate(all_slice_memberships, axis=0)

    for i, sl in enumerate(slices):
        # Set the Slice category using the SliceBuilder's category
        sl.category = self.category

        # Append to the lineage
        sl.add_to_lineage(
            category=str(self.category.capitalize()),
            identifier=self.identifiers[i],
            columns=strings_as_json(columns),
        )

    return slices, slice_membership
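# The incremental build above appends one sliced batch per slice and one
# membership block per batch, then stacks the blocks into a single
# (num_examples x num_slices) matrix. A minimal sketch of that
# accumulation, assuming 2 slices and boolean membership per batch:
import numpy as np

all_slice_memberships = []
for batch_membership in ([[True, False], [False, True]], [[True, True]]):
    all_slice_memberships.append(np.asarray(batch_membership))
slice_membership = np.concatenate(all_slice_memberships, axis=0)
assert slice_membership.shape == (3, 2)  # 3 examples, 2 slices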