def prepare_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    *args,
    **kwargs,
) -> None:
    """Apply a preparation function to the dataset.

    Use this to update attributes of `self`.

    Args:
        dataset: dataset
        columns: list of columns
        batch_size: batch size for preparation
        *args: optional additional arguments
        **kwargs: optional additional keyword arguments
    """
    # Set the data format
    with dataset.format(
        columns + self._filter_prerequisite_columns(columns, dataset.column_names)
    ):
        # Batch the dataset, and prepare each batch
        for batch in dataset.batch(batch_size):
            try:
                # Check if the `prepare_batch` function has been implemented
                self.prepare_batch(
                    batch=batch,
                    columns=columns,
                    *args,
                    **kwargs,
                )
            except NotImplementedError:
                break
def __init__(self):
    # Create a fake dataset
    self.dataset = Dataset.from_batch(
        {
            "text_a": [
                "Before the actor slept, the senator ran.",
                "The lawyer knew that the judges shouted.",
                "If the actor slept, the judge saw the artist.",
                "The lawyers resigned, or the artist slept.",
            ],
            "text_b": [
                "The actor slept.",
                "The judges shouted.",
                "The actor slept.",
                "The artist slept.",
            ],
            "label": [0, 0, 1, 1],
            "z": [1, 0, 1, 0],
            "fast": [False, True, True, False],
        },
        identifier=Identifier(_name="MockDataset", version="2.0"),
    )

    # Keep a copy of the original
    self.original_dataset = deepcopy(self.dataset)

    assert len(self.dataset) == 4
def __init__(self):
    # Create a fake batch of data
    self.batch = {
        "text": [
            "The man is walking.",
            "The man is running.",
            "The woman is sprinting.",
            "The woman is resting.",
            "The hobbit is flying.",
            "The hobbit is swimming.",
        ],
        "label": [0, 0, 1, 1, 0, 0],
        "z": [1, 0, 1, 0, 1, 0],
        "fast": [False, True, True, False, False, False],
        "metadata": [
            {"source": "real"},
            {"source": "real"},
            {"source": "real"},
            {"source": "real"},
            {"source": "fictional"},
            {"source": "fictional"},
        ],
    }

    # Create a fake dataset
    self.dataset = Dataset.from_batch(
        self.batch,
        identifier=Identifier(_name="MockDataset", version="1.0"),
    )

    # Keep a copy of the original
    self.original_dataset = deepcopy(self.dataset)

    assert len(self.dataset) == 6
def test_load_dataset(self):
    # Load the first 20 examples of the boolq dataset
    dataset = Dataset.load_dataset("boolq", split="train[:20]")

    # Check that we got 20 examples
    self.assertTrue(isinstance(dataset, Dataset))
    self.assertEqual(len(dataset), 20)
def test_from_jsonl(self):
    # Create a temporary directory
    os.makedirs("tmp", exist_ok=True)

    # Create a jsonl file with data
    with jsonlines.open("tmp/data.jsonl", "w") as writer:
        writer.write_all(
            transpose_batch(
                {
                    "a": [1, 2, 3],
                    "b": [True, False, True],
                    "c": ["x", "y", "z"],
                    "d": [{"e": 2}, {"e": 3}, {"e": 4}],
                }
            )
        )

    # Load the dataset
    dataset = Dataset.from_jsonl(
        json_path="tmp/data.jsonl",
        identifier=Identifier(_name="MockJSONDataset"),
    )
    self.assertEqual(set(dataset.column_names), {"a", "b", "c", "d", "index"})
    self.assertEqual(len(dataset), 3)

    # Remove the temporary directory
    shutil.rmtree("tmp")
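# Illustrative sketch (not part of the original test suite): the test above relies
# on a `transpose_batch` helper. Assuming it converts a columnar batch (a dict of
# equal-length lists) into a list of per-example dicts for jsonlines, a minimal
# version could look like this. The helper below is hypothetical and only shown to
# clarify the data layout written to the jsonl file.
def transpose_batch_sketch(batch):
    """Turn {"a": [1, 2], "b": ["x", "y"]} into [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}]."""
    keys = list(batch.keys())
    return [dict(zip(keys, values)) for values in zip(*(batch[k] for k in keys))]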
def test_load_dataset(self):
    # Load the first 20 examples of the boolq dataset
    dataset = Dataset.load_dataset(
        "boolq", split="train[:20]", dataset_fmt="in_memory"
    )

    # Check that we got 20 examples
    self.assertTrue(isinstance(dataset, Dataset))
    self.assertEqual(len(dataset), 20)
    self.assertTrue(isinstance(dataset.identifier.parameters["version"], str))
def prepare_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    *args,
    **kwargs,
) -> None:
    """Preparation that is applied before the CachedOperation is run.

    Many CachedOperations require a full pass over the dataset to precompute
    some variables before the core operation can actually be applied. For
    example, to create a Bag-of-Words representation, a dataset vocabulary must
    first be constructed so that only tokens that are frequently seen across
    the dataset are kept.

    Args:
        dataset: Dataset
        columns: list of columns
        batch_size: batch size for preparation
        *args: optional additional arguments
        **kwargs: optional additional keyword arguments
    """
    # Set the data format
    dataset.set_format(columns)

    # Batch the dataset, and prepare each batch
    for batch in dataset.batch(batch_size):
        try:
            # Check if the `prepare_batch` function has been implemented
            self.prepare_batch(
                batch=batch,
                columns=columns,
                *args,
                **kwargs,
            )
        except NotImplementedError:
            break

    # Reset the data format
    dataset.reset_format()
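# Hedged sketch (not from the library source): a subclass might override
# `prepare_batch` to accumulate state across the full pass that `prepare_dataset`
# performs above, e.g. building a vocabulary for a Bag-of-Words representation.
# The class name and `vocab_counts` attribute below are hypothetical.
from collections import Counter


class BagOfWordsOpSketch:
    def __init__(self):
        # Token counts accumulated across all batches of the dataset
        self.vocab_counts = Counter()

    def prepare_batch(self, batch, columns, *args, **kwargs):
        # Called once per batch by `prepare_dataset` before the core operation runs
        for text in batch[columns[0]]:
            self.vocab_counts.update(text.split())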
def test_interleave(self):
    # Interleave the dataset with itself
    dataset = Dataset.interleave(
        [self.testbed.dataset, self.testbed.dataset],
        identifier=Identifier(_name="MockInterleavedDataset"),
    )

    # Check that the elements match up
    for i, x in enumerate(dataset):
        self.assertEqual(x, self.testbed.dataset[i // 2])
    self.assertEqual(len(dataset), len(self.testbed.dataset) * 2)
def process_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
) -> Dataset:
    """Apply the cached operation to a dataset."""
    return dataset.update(
        partial(self.process_batch, columns=columns),
        batched=True,
        batch_size=batch_size,
    )
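# Hedged usage sketch: `process_dataset` defers to `Dataset.update`, which maps
# `process_batch` over the dataset in batches. The wrapper below only shows the
# call shape; `op` (a CachedOperation subclass instance), `dataset`, and the
# column name are placeholders.
def run_cached_operation_sketch(op, dataset, columns=("text",), batch_size=32):
    # Returns a new Dataset with the operation's outputs stored on each example
    return op.process_dataset(dataset, columns=list(columns), batch_size=batch_size)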
def test_chain(self):
    # Chain the dataset with itself
    dataset = Dataset.chain(
        [self.testbed.dataset, self.testbed.dataset],
        identifier=Identifier(_name="MockChainedDataset"),
    )

    # Check that the elements match up
    for i, x in enumerate(dataset):
        self.assertEqual(x, self.testbed.dataset[i % len(self.testbed.dataset)])
    self.assertEqual(len(dataset), len(self.testbed.dataset) * 2)
def test_save_load(self):
    # Create a temporary directory
    os.makedirs("tmp", exist_ok=True)

    # Save the dataset to disk
    self.testbed.dataset.save_to_disk(path="tmp")

    # Load the dataset from disk
    dataset = Dataset.load_from_disk(path="tmp")

    # Remove the temporary directory
    shutil.rmtree("tmp")

    self.assertEqual(dataset.features, self.testbed.dataset.features)
def test_save_load(self):
    # Create a temporary directory
    os.mkdir("tmp")

    # Save the dataset to disk
    self.testbed.dataset.save(path="tmp")

    # Load the dataset from disk
    dataset = Dataset.load(path="tmp")

    # Remove the temporary directory
    shutil.rmtree("tmp")

    self.assertEqual(dataset.features, self.testbed.dataset.features)
def load(cls, path: str) -> DevBench:
    """Load a devbench from disk.

    Args:
        path: string path to the devbench directory

    Returns:
        the loaded DevBench
    """
    # Path to the save directory
    savedir = pathlib.Path(path)

    # Load all the slices
    slices = []
    for sl_path in tqdm(list((savedir / "slices").glob("*"))):
        try:
            slices.append(Slice.load_from_disk(str(sl_path)))
        except FileNotFoundError:
            continue

    # Load the dataset
    dataset = Dataset.load_from_disk(str(savedir / "dataset"))

    # Load the metrics
    metrics = dill.load(open(str(savedir / "metrics.dill"), "rb"))

    # Load the aggregators
    aggregators = dill.load(open(str(savedir / "aggregators.dill"), "rb"))

    # Load the metadata
    _ = dill.load(open(str(savedir / "metadata.dill"), "rb"))

    # Create the devbench
    devbench = cls(dataset=dataset)

    # Set previously stored metrics
    devbench.metrics = metrics

    # Set previously stored aggregators
    devbench.aggregators = aggregators

    # Set the slices
    devbench.add_slices(slices)

    # Load the version info
    with open(str(savedir / "version.dill"), "rb") as f:
        devbench._loads_version(f.read())

    return devbench
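# Hedged usage sketch: `load` is a classmethod, so restoring a saved devbench only
# needs its directory path. The path below is a placeholder.
def load_devbench_sketch(path="path/to/devbench"):
    # The stored dataset, metrics, aggregators, and slices are restored on load
    return DevBench.load(path)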
def remap_schema(self, dataset: Dataset):
    # Ground the schema to the dataset
    input_grounding, reversed_input_grounding = self.input_schema.ground(
        dataset.features
    )
    output_grounding, reversed_output_grounding = self.output_schema.ground(
        dataset.features
    )

    # Construct a map_fn that remaps the dataset schema
    def map_fn(example):
        return tz.merge(
            {k: example[input_grounding[k]] for k in self.input_schema.features},
            {k: example[output_grounding[k]] for k in self.output_schema.features},
        )

    return dataset.map(
        map_fn,
        remove_columns=list(reversed_input_grounding.keys())
        + list(reversed_output_grounding.keys()),
    )
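# Illustrative sketch with assumed semantics (not from the source): `ground`
# appears to return a mapping from each schema feature name to the matching
# dataset column, so the remapping in `map_fn` above amounts to renaming columns
# per example. With hypothetical groundings:
def remap_example_sketch(example, input_grounding, output_grounding):
    # e.g. input_grounding={"premise": "text_a"}, output_grounding={"label": "label"}
    remapped = {k: example[col] for k, col in input_grounding.items()}
    remapped.update({k: example[col] for k, col in output_grounding.items()})
    return remapped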
def test_from_batches(self):
    # Build a dataset from multiple batches
    dataset = Dataset.from_batches(
        [
            {
                "a": [1, 2, 3],
                "b": [True, False, True],
                "c": ["x", "y", "z"],
                "d": [{"e": 2}, {"e": 3}, {"e": 4}],
            }
        ]
        * 3,
        identifier=Identifier(_name="MyDataset"),
    )

    self.assertEqual(set(dataset.column_names), {"a", "b", "c", "d", "index"})
    self.assertEqual(len(dataset), 9)
def evaluate(
    self,
    dataset: Dataset,
    input_columns: List[str],
    output_columns: List[str],
    batch_size: int = 32,
    metrics: List[str] = None,
    coerce_fn: Callable = None,
):
    # TODO(karan): generalize to TF2

    # Reset the dataset format
    dataset.reset_format()
    dataset.set_format(columns=input_columns + output_columns)

    # TODO(karan): check that the Dataset conforms to the task definition
    # TODO(karan): figure out how the output_columns will be used by the metrics

    predictions = []
    targets = []

    # Loop and apply the prediction function
    # TODO(karan): not using .map() here in order to get more fine-grained
    #  control over devices
    for idx in range(0, len(dataset), batch_size):
        # Create the batch
        batch = dataset[idx : idx + batch_size]

        # Predict on the batch
        prediction_dict = self.predict_batch(
            batch=batch, input_columns=input_columns
        )

        # Coerce the predictions
        if coerce_fn:
            prediction_dict = coerce_fn(prediction_dict)

        # Grab the raw target key/values
        target_dict = tz.keyfilter(lambda k: k in output_columns, batch)

        # TODO(karan): general version for non-classification problems
        # TODO(karan): move this to the right device
        if self.task.classification():
            target_dict = tz.valmap(lambda v: torch.tensor(v), target_dict)

        # TODO(karan): incremental metric computation here
        # Append the predictions and targets
        predictions.append(prediction_dict)
        targets.append(target_dict)

    # Consolidate the predictions and targets
    if self.task.classification():
        # TODO(karan): Need to store predictions and outputs from the model
        predictions = tz.merge_with(lambda v: torch.cat(v).to("cpu"), *predictions)
        targets = tz.merge_with(lambda v: torch.cat(v).to("cpu"), *targets)
    else:
        predictions = tz.merge_with(
            lambda x: list(itertools.chain.from_iterable(x)), *predictions
        )
        targets = tz.merge_with(
            lambda x: list(itertools.chain.from_iterable(x)), *targets
        )

    # Compute the metrics
    # TODO(karan): generalize this code to support metric computation for any task
    # Assumes classification, so the output_columns contains a single key for the
    # label
    if self.task.classification():
        assert len(output_columns) == 1  # , "Only supports classification."
        num_classes = self.task.output_schema.features[
            list(self.task.output_schema.keys())[0]
        ].num_classes

    labels = targets[list(targets.keys())[0]]

    if metrics is None:
        if self.task is None:
            raise ValueError(
                "Must specify metrics if model not associated with task"
            )
        metrics = self.task.metrics

    pred = predictions["pred"].to(self.device)
    target = labels.to(self.device)

    evaluation_dict = {
        metric: compute_metric(metric, pred, target, num_classes)
        for metric in metrics
    }

    # Reset the data format
    dataset.reset_format()

    return evaluation_dict
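# Hedged usage sketch: for a classification task, evaluation reduces to a single
# call. The column names are placeholders and `model` is any instance exposing the
# `evaluate` method above.
def evaluate_model_sketch(model, dataset):
    # Returns a dict mapping each metric name to its computed value
    return model.evaluate(
        dataset,
        input_columns=["text"],
        output_columns=["label"],
        batch_size=32,
    )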
def process_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    # mask: List[int] = None,
    num_proc: int = None,
    *args,
    **kwargs,
) -> Tuple[List[Slice], np.ndarray]:
    # Create slices using the dataset
    slices = [Slice(dataset) for _ in range(len(self.identifiers))]
    all_slice_memberships = []

    # Batch the dataset, and process each batch
    for batch in dataset.batch(batch_size):
        # Process the batch
        _, slice_memberships = self.process_batch(
            batch=batch,
            columns=columns,
            *args,
            **kwargs,
        )

        # Keep track of the slice memberships
        all_slice_memberships.append(slice_memberships)

    # Create a single slice label matrix
    slice_membership = np.concatenate(all_slice_memberships, axis=0)

    for i, sl in enumerate(slices):
        # Set the visible rows for each slice
        sl.set_visible_rows(np.where(slice_membership[:, i])[0])

        # Set the Slice category using the SliceBuilder's category
        sl.category = self.category

        # Append to the lineage
        sl.add_to_lineage(
            category=str(self.category.capitalize()),
            identifier=self.identifiers[i],
            columns=strings_as_json(columns),
        )

        # # Create the lineage
        # sl.lineage = [
        #     (str(Dataset.__name__), dataset.identifier),
        #     (
        #         str(self.category.capitalize()),
        #         self.identifiers[i],
        #         strings_as_json(columns),
        #     ),
        # ]
        # if isinstance(dataset, Slice):
        #     # Prepend the Slice's lineage instead, if the dataset was a slice
        #     sl.lineage = dataset.lineage + [
        #         (
        #             str(self.category.capitalize()),
        #             self.identifiers[i],
        #             strings_as_json(columns),
        #         )
        #     ]

    return slices, slice_membership
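# Hedged usage sketch: a SliceBuilder's `process_dataset` returns one Slice per
# identifier plus a (num_examples x num_slices) membership matrix. The builder
# instance and column name are placeholders.
def build_slices_sketch(slicebuilder, dataset):
    slices, slice_membership = slicebuilder.process_dataset(
        dataset, columns=["text"], batch_size=32
    )
    # Each column of slice_membership marks which examples belong to that slice
    assert slice_membership.shape[1] == len(slices)
    return slices, slice_membership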
def stow(
    dataset: Dataset,
    cached_ops: Dict[CachedOperation, List[List[str]]],
    batch_size: int = 32,
    load_from_cache_file: bool = True,
):
    """Apply a list of cached operations in sequence."""

    # Check the InteractionTape to remove CachedOperations that have already
    # been stowed
    for cached_op, list_of_columns in list(cached_ops.items()):
        indices_to_remove = []
        for i, columns in enumerate(list(list_of_columns)):
            if dataset.check_tape(
                path=[CACHEDOPS],
                identifiers=cached_op.identifier,
                columns=columns,
            ):
                # Remove the columns at index i
                indices_to_remove.append(i)

        # Remove the columns that are already cached
        for index in sorted(indices_to_remove, reverse=True):
            columns = cached_ops[cached_op].pop(index)
            print(f"skipped: {cached_op.identifier} -> {columns}", flush=True)

        # Check if list_of_columns is now empty
        if not cached_ops[cached_op]:
            # Remove the op entirely
            cached_ops.pop(cached_op)

    for cached_op, list_of_columns in cached_ops.items():
        for columns in list_of_columns:
            dataset = cached_op(dataset, columns=columns, batch_size=batch_size)

    # def _map_fn(batch: Batch):
    #     """
    #     Consolidate the application of the CachedOperations passed to stow into a
    #     single mappable function.
    #     """
    #     for cached_op, list_of_columns in cached_ops.items():
    #         for columns in list_of_columns:
    #             batch = cached_op(batch, columns=columns)
    #
    #     return batch
    #
    # # Compute the hash value
    # val = 0
    # for cached_op, list_of_columns in cached_ops.items():
    #     for columns in list_of_columns:
    #         val ^= cached_op.get_cache_hash(columns=columns)
    #
    # # Combine with the hash for the dataset on which the cached ops are applied
    # val ^= persistent_hash(
    #     # TODO(karan): move this to Dataset
    #     "-".join(
    #         "-".join(str(k) + "-" + str(v) for k, v in f.items())
    #         for f in dataset._data_files
    #     )
    # )
    #
    # # Map the cached operations over the dataset
    # try:
    #     dataset = dataset.map(
    #         _map_fn,
    #         batched=True,
    #         batch_size=32,
    #         cache_file_name='cache-' + str(abs(val)) + '.arrow',
    #         load_from_cache_file=load_from_cache_file,
    #     )
    # except TypeError:
    #     # Batch the dataset, and process each batch
    #     all_batches = [_map_fn(batch=batch) for batch in dataset.batch(batch_size)]
    #
    #     # Update the dataset efficiently by reusing all_batches
    #     dataset = dataset.map(
    #         lambda examples, indices: all_batches[indices[0] // batch_size],
    #         batched=True,
    #         batch_size=batch_size,
    #         with_indices=True,
    #     )

    # Update the Dataset history
    for cached_op, list_of_columns in cached_ops.items():
        for columns in list_of_columns:
            dataset.update_tape(
                path=[CACHEDOPS],
                identifiers=cached_op.identifier,
                columns=columns,
            )

    return dataset
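# Hedged usage sketch: `stow` expects a mapping from CachedOperation instances to
# lists of column groups, and skips any (operation, columns) pair already recorded
# on the dataset's interaction tape. The two operations and column names below are
# placeholders.
def stow_operations_sketch(dataset, first_op, second_op):
    return stow(
        dataset,
        cached_ops={
            first_op: [["text_a"], ["text_b"]],
            second_op: [["text_a"]],
        },
        batch_size=32,
    )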
def process_batch(
    self,
    batch: Batch,
    columns: List[str],
    mask: List[int] = None,
    store_compressed: bool = True,
    store: bool = True,
    *args,
    **kwargs,
) -> Tuple[Batch, List[Batch], Optional[np.ndarray]]:
    # Determine the size of the batch
    batch_size = len(batch[list(batch.keys())[0]])

    # Construct the matrix of slice labels: (batch_size x num_slices)
    slice_membership = np.ones((batch_size, self.num_slices), dtype=np.int32)

    # Uncache the batch to construct the skeleton for transformed batches
    skeleton_batches = [
        Dataset.uncached_batch(batch) for _ in range(self.num_slices)
    ]

    # Set the index for the skeleton batches
    for j, skeleton_batch in enumerate(skeleton_batches):
        skeleton_batch["index"] = [
            f"{idx}-{self.identifiers[j]}" for idx in skeleton_batch["index"]
        ]

    # Apply the SliceBuilder's core functionality
    transformed_batches, slice_membership = self.apply(
        skeleton_batches=skeleton_batches,
        slice_membership=slice_membership,
        batch=batch,
        columns=columns,
        *args,
        **kwargs,
    )

    # Store the transformed examples
    updates = self.construct_updates(
        transformed_batches=transformed_batches,
        slice_membership=slice_membership,
        batch_size=batch_size,
        columns=columns,
        mask=mask,
        compress=store_compressed,
    )

    # Remove transformed examples where slice_membership[i, :] = 0 before returning
    transformed_batches = [
        self.filter_batch_by_slice_membership(
            batch=transformed_batch,
            slice_membership=slice_membership[:, j : j + 1],
        )[0]
        for j, transformed_batch in enumerate(transformed_batches)
    ]

    if store:
        batch = self.store(
            batch=batch,
            updates=updates,
        )

    return batch, transformed_batches, slice_membership
def process_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    mask: List[int] = None,
    store_compressed: bool = True,
    store: bool = True,
    num_proc: int = None,
    *args,
    **kwargs,
) -> Tuple[Dataset, List[Slice], np.ndarray]:
    """Apply a SliceBuilder to a dataset.

    Args:
        dataset: Dataset
        columns: list of columns
        batch_size: integer batch size
        mask: boolean or integer mask array, mask[i] = True means that the ith
            slice will be masked out
        store_compressed: whether to store in a compressed format
        store: whether to store the results along with the examples in the Dataset
        num_proc: number of processes for multiprocessing
        *args: optional additional arguments
        **kwargs: optional additional keyword arguments

    Returns:
        tuple of (Dataset, list of Slices, matrix of (example, slice) membership)
    """
    # Prepare the dataset
    dataset = self.prepare_dataset(
        dataset=dataset,
        columns=columns,
        batch_size=batch_size,
        mask=mask,
        store_compressed=store_compressed,
        store=store,
        *args,
        **kwargs,
    )

    # Compute a hash
    val = persistent_hash(str(dataset.identifier)) ^ dataset.hash_interactions()
    for i, identifier in enumerate(self.identifiers):
        if not mask[i]:
            val ^= persistent_hash(
                str(identifier) + str(strings_as_json(columns))
            )

    try:
        # Map the SliceBuilder over the dataset
        all_sliced_batches = []
        all_slice_memberships = []

        def _map_fn(batch):
            """Map function for processing batches.

            Note that using this map_fn in a stateful way is dangerous, since
            every invocation of this function appends to the all_sliced_batches
            list. The .map() function will invoke this once for testing before
            performing the map, so we discard the first entry inserted into
            all_sliced_batches.
            """
            batch, sliced_batches, slice_membership = self.process_batch(
                batch=batch,
                columns=columns,
                mask=mask,
                store_compressed=store_compressed,
                store=store,
                *args,
                **kwargs,
            )
            all_sliced_batches.append(sliced_batches)
            all_slice_memberships.append(slice_membership)
            return batch

        dataset = dataset.map(
            _map_fn,
            batched=True,
            batch_size=batch_size,
            # FIXME(karan): enable this by adding logic for generating
            #  all_sliced_batches and all_slice_memberships
            #  when loading from cache file
            load_from_cache_file=False,
            # The cache file name is a XOR of the interaction history and the
            # current operation
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(val)) + ".arrow")
            ),
        )

        # Remove the first entry (see _map_fn)
        all_sliced_batches = all_sliced_batches[1:]
        all_slice_memberships = all_slice_memberships[1:]

    except:  # noqa
        # Batch the dataset, and process each batch
        all_batches, all_sliced_batches, all_slice_memberships = zip(
            *[
                self.process_batch(
                    batch=batch,
                    columns=columns,
                    mask=mask,
                    store_compressed=store_compressed,
                    store=store,
                    *args,
                    **kwargs,
                )
                for batch in dataset.batch(batch_size)
            ]
        )

        # Update the dataset efficiently by reusing all_batches
        dataset = dataset.map(
            lambda examples, indices: all_batches[indices[0] // batch_size],
            batched=True,
            batch_size=batch_size,
            with_indices=True,
            load_from_cache_file=False,
            # The cache file name is a XOR of the interaction history and the
            # current operation
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(val)) + ".arrow")
            ),
        )

    # Create a single slice label matrix
    slice_membership = np.concatenate(all_slice_memberships, axis=0)

    slice_cache_hashes = []
    for identifier in self.identifiers:
        slice_cache_hashes.append(val ^ persistent_hash(str(identifier)))

    if not num_proc or num_proc == 1:
        # Construct slices
        slices = []
        for i, slice_batches in enumerate(zip(*all_sliced_batches)):
            slices.append(
                create_slice(
                    (
                        dataset,
                        slice_membership,
                        slice_batches,
                        i,
                        batch_size,
                        slice_cache_hashes[i],
                    )
                )
            )
    else:
        # Parallelized slice construction
        with Pool(num_proc) as pool:
            slices = pool.map(
                create_slice,
                [
                    (
                        dataset,
                        slice_membership,
                        slice_batches,
                        i,
                        batch_size,
                        slice_cache_hashes[i],
                    )
                    for i, slice_batches in enumerate(zip(*all_sliced_batches))
                ],
            )

    # TODO(karan): make this more systematic
    # TODO(karan): fix bug when slicing a Slice
    for i, sl in enumerate(slices):
        # # Set the Slice features
        # sl.info.features = dataset.features

        # Set the Slice category using the SliceBuilder's category
        sl.category = self.category

        # Create the lineage
        sl.lineage = [
            (str(Dataset.__name__), dataset.identifier),
            (
                str(self.category.capitalize()),
                self.identifiers[i],
                strings_as_json(columns),
            ),
        ]
        if isinstance(dataset, Slice):
            # Prepend the Slice's lineage instead, if the dataset was a slice
            sl.lineage = dataset.lineage + [
                (
                    str(self.category.capitalize()),
                    self.identifiers[i],
                    strings_as_json(columns),
                )
            ]

    return dataset, slices, slice_membership
def prepare_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    mask: List[int] = None,
    store_compressed: bool = True,
    store: bool = True,
    *args,
    **kwargs,
) -> Dataset:
    # Compute the hash for this operation
    # FIXME(karan): this is repeated inside process_dataset
    val = persistent_hash(str(dataset.identifier)) ^ dataset.hash_interactions()
    for i, identifier in enumerate(self.identifiers):
        if not mask[i]:
            val ^= persistent_hash(
                str(identifier) + str(strings_as_json(columns))
            )

    try:
        return dataset.map(
            partial(
                self.prepare_batch,
                columns=columns,
                mask=mask,
                store_compressed=store_compressed,
                store=store,
                *args,
                **kwargs,
            ),
            batched=True,
            batch_size=batch_size,
            load_from_cache_file=False,
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(val)) + "-prep.arrow")
            ),
        )
    except:  # TypeError or PicklingError or AttributeError  # noqa
        # Batch the dataset, and process each batch
        all_batches = [
            self.prepare_batch(
                batch=batch,
                columns=columns,
                mask=mask,
                store_compressed=store_compressed,
                store=store,
                *args,
                **kwargs,
            )
            for batch in dataset.batch(batch_size)
        ]

        # Update the dataset efficiently by reusing all_batches
        return dataset.map(
            lambda examples, indices: all_batches[indices[0] // batch_size],
            batched=True,
            batch_size=batch_size,
            with_indices=True,
            load_from_cache_file=False,
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(val)) + "-prep.arrow")
            ),
        )
def process_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    # mask: List[int] = None,
    num_proc: int = None,
    *args,
    **kwargs,
) -> Tuple[List[Slice], np.ndarray]:
    """Apply a SliceBuilder to a dataset.

    Args:
        dataset: Dataset
        columns: list of columns
        batch_size: integer batch size
        # mask: boolean or integer mask array, mask[i] = True means that the ith
        #     slice will be masked out
        num_proc: number of processes for multiprocessing
        *args: optional additional arguments
        **kwargs: optional additional keyword arguments

    Returns:
        tuple of (list of Slices, matrix of (example, slice) membership)
    """
    # # Compute a hash
    # val = persistent_hash(str(dataset.identifier)) ^ dataset.hash_interactions()
    # for i, identifier in enumerate(self.identifiers):
    #     if not mask[i]:
    #         val ^= persistent_hash(
    #             str(identifier) + str(strings_as_json(columns))
    #         )
    #
    # try:
    #     # Map the SliceBuilder over the dataset
    #     all_sliced_batches = []
    #     all_slice_memberships = []
    #
    #     def _map_fn(batch):
    #         """Map function for processing batches.
    #
    #         Note that using this map_fn in a stateful way is
    #         dangerous, since every invocation of this function
    #         appends to the all_slice_batches list. The .map()
    #         function will invoke this once for testing before
    #         performing the map, so we discard the first entry
    #         inserted into all_sliced_batches.
    #         """
    #         batch, sliced_batches, slice_membership = self.process_batch(
    #             batch=batch,
    #             columns=columns,
    #             mask=mask,
    #             store_compressed=store_compressed,
    #             store=store,
    #             *args,
    #             **kwargs,
    #         )
    #         all_sliced_batches.append(sliced_batches)
    #         all_slice_memberships.append(slice_membership)
    #         return batch
    #
    #     dataset = dataset.map(
    #         _map_fn,
    #         batched=True,
    #         batch_size=batch_size,
    #         # FIXME(karan): enable this by adding logic for generating
    #         #  all_sliced_batches and all_slice_memberships
    #         #  when loading from cache file
    #         load_from_cache_file=False,
    #         # The cache file name is a XOR of the interaction history and the
    #         # current operation
    #         cache_file_name=str(
    #             dataset.logdir / ("cache-" + str(abs(val)) + ".arrow")
    #         ),
    #     )
    #
    #     # Remove the first entry (see _map_fn)
    #     all_sliced_batches = all_sliced_batches[1:]
    #     all_slice_memberships = all_slice_memberships[1:]
    #
    # except:  # noqa
    #     all_batches, all_sliced_batches, all_slice_memberships = zip(
    #         *[
    #             self.process_batch(
    #                 batch=batch,
    #                 columns=columns,
    #                 mask=mask,
    #                 store_compressed=store_compressed,
    #                 store=store,
    #                 *args,
    #                 **kwargs,
    #             )
    #             for batch in dataset.batch(batch_size)
    #         ]
    #     )
    #     # Update the dataset efficiently by reusing all_batches
    #     dataset = dataset.map(
    #         lambda examples, indices: all_batches[indices[0] // batch_size],
    #         batched=True,
    #         batch_size=batch_size,
    #         with_indices=True,
    #         load_from_cache_file=False,
    #         # The cache file name is a XOR of the interaction history and the
    #         # current operation
    #         cache_file_name=str(
    #             dataset.logdir / ("cache-" + str(abs(val)) + ".arrow")
    #         ),
    #     )

    # Create slices
    slices = [Slice() for _ in range(len(self.identifiers))]
    all_slice_memberships = []

    # Batch the dataset, and process each batch
    for batch in dataset.batch(batch_size):
        # Process the batch
        sliced_batches, slice_memberships = self.process_batch(
            batch=batch,
            columns=columns,
            *args,
            **kwargs,
        )

        # Incrementally build the slices
        for sl, sl_batch in zip(slices, sliced_batches):
            sl._dataset.append(sl_batch)

        # Keep track of the slice memberships
        all_slice_memberships.append(slice_memberships)

    # Create a single slice label matrix
    slice_membership = np.concatenate(all_slice_memberships, axis=0)

    # slice_cache_hashes = []
    # for identifier in self.identifiers:
    #     slice_cache_hashes.append(val ^ persistent_hash(str(identifier)))
    #
    # if not num_proc or num_proc == 1:
    #     # Construct slices
    #     slices = []
    #     for i, slice_batches in enumerate(zip(*all_sliced_batches)):
    #         slices.append(
    #             create_slice(
    #                 (
    #                     dataset,
    #                     slice_membership,
    #                     slice_batches,
    #                     i,
    #                     batch_size,
    #                     slice_cache_hashes[i],
    #                 )
    #             )
    #         )
    # else:
    #     # Parallelized slice construction
    #     with Pool(num_proc) as pool:
    #         slices = pool.map(
    #             create_slice,
    #             [
    #                 (
    #                     dataset,
    #                     slice_membership,
    #                     slice_batches,
    #                     i,
    #                     batch_size,
    #                     slice_cache_hashes[i],
    #                 )
    #                 for i, slice_batches in enumerate(zip(*all_sliced_batches))
    #             ],
    #         )

    for i, sl in enumerate(slices):
        # Set the Slice category using the SliceBuilder's category
        sl.category = self.category

        # Append to the lineage
        sl.add_to_lineage(
            category=str(self.category.capitalize()),
            identifier=self.identifiers[i],
            columns=strings_as_json(columns),
        )

        # # Create the lineage
        # sl.lineage = [
        #     (str(Dataset.__name__), dataset.identifier),
        #     (
        #         str(self.category.capitalize()),
        #         self.identifiers[i],
        #         strings_as_json(columns),
        #     ),
        # ]
        # if isinstance(dataset, Slice):
        #     # Prepend the Slice's lineage instead, if the dataset was a slice
        #     sl.lineage = dataset.lineage + [
        #         (
        #             str(self.category.capitalize()),
        #             self.identifiers[i],
        #             strings_as_json(columns),
        #         )
        #     ]

    return slices, slice_membership