def prepare_dataset(
    self,
    dataset: Dataset,
    columns: List[str],
    batch_size: int = 32,
    *args,
    **kwargs,
) -> None:
    """Preparation that is applied before the CachedOperation.

    Many CachedOperations require a full pass over the dataset to precompute
    some variables before the core operation can actually be applied, e.g. to
    create a Bag-of-Words representation, or to construct a dataset vocabulary
    that keeps only tokens frequently seen across the dataset.

    The dataset format is restricted to `columns` for the duration of the
    pass and is always reset afterwards, even if `prepare_batch` fails.

    Args:
        dataset: Dataset to make a full pass over.
        columns: list of columns the dataset format is set to, and that are
            forwarded to `prepare_batch`.
        batch_size: batch size passed to `dataset.batch(..)`.
        *args: extra positional arguments forwarded to `prepare_batch`.
        **kwargs: extra keyword arguments forwarded to `prepare_batch`.

    Returns:
        None. Any state computed during the pass is kept by `prepare_batch`.
    """
    # Set the data format
    dataset.set_format(columns)

    try:
        # Batch the dataset, and prepare each batch
        for batch in dataset.batch(batch_size):
            try:
                self.prepare_batch(
                    *args,
                    batch=batch,
                    columns=columns,
                    **kwargs,
                )
            except NotImplementedError:
                # `prepare_batch` has not been implemented, so there is
                # nothing to prepare: stop iterating over the dataset.
                break
    finally:
        # Reset the data format, including when `prepare_batch` raised, so
        # the caller's dataset is never left restricted to `columns`.
        dataset.reset_format()
def evaluate(
    self,
    dataset: Dataset,
    input_columns: List[str],
    output_columns: List[str],
    batch_size: int = 32,
    metrics: List[str] = None,
    coerce_fn: Callable = None,
):
    """Run the model over a dataset and compute evaluation metrics.

    The dataset format is restricted to `input_columns + output_columns`
    while evaluating and is always reset before returning or raising.

    Args:
        dataset: Dataset to evaluate on.
        input_columns: columns passed to `predict_batch` as model inputs.
        output_columns: columns holding the targets. Classification tasks
            require exactly one output column.
        batch_size: number of rows sliced from the dataset per prediction
            call.
        metrics: names of the metrics to compute. Defaults to
            `self.task.metrics` when None.
        coerce_fn: optional callable applied to each batch's prediction
            dict before it is accumulated.

    Returns:
        A dict mapping each metric name to the value returned by
        `compute_metric`.

    Raises:
        ValueError: if `metrics` is None and the model has no task.
        NotImplementedError: if the task is not a classification task, since
            metric computation is only implemented for classification.
    """
    # TODO(karan): generalize to TF2

    # Resolve the metrics before doing any work. This check used to run
    # after `self.task` had already been dereferenced, so it could never
    # fire.
    if metrics is None:
        if self.task is None:
            raise ValueError(
                "Must specify metrics if model not associated with task")
        metrics = self.task.metrics

    # The task type does not change while evaluating, so query it once.
    is_classification = self.task.classification()

    # Reset the dataset format
    dataset.reset_format()
    dataset.set_format(columns=input_columns + output_columns)

    try:
        # TODO(karan): check that the Dataset conforms to the task definition
        # TODO(karan): figure out how the output_columns will be used by the
        # metrics
        predictions = []
        targets = []

        # Loop and apply the prediction function
        # TODO(karan): not using .map() here in order to get more
        # fine-grained control over devices
        for start in range(0, len(dataset), batch_size):
            # Create the batch
            batch = dataset[start:start + batch_size]

            # Predict on the batch
            prediction_dict = self.predict_batch(
                batch=batch, input_columns=input_columns)

            # Coerce the predictions
            if coerce_fn:
                prediction_dict = coerce_fn(prediction_dict)

            # Grab the raw target key/values
            target_dict = tz.keyfilter(lambda k: k in output_columns, batch)

            # TODO(karan): general version for non-classification problems
            # TODO(karan): move this to the right device
            if is_classification:
                target_dict = tz.valmap(torch.tensor, target_dict)

            # TODO(karan): incremental metric computation here
            # Append the predictions and targets
            predictions.append(prediction_dict)
            targets.append(target_dict)

        # Consolidate the per-batch predictions and targets
        if is_classification:
            # TODO(karan): Need to store predictions and outputs from the
            # model
            merged_predictions = tz.merge_with(
                lambda v: torch.cat(v).to("cpu"), *predictions)
            merged_targets = tz.merge_with(
                lambda v: torch.cat(v).to("cpu"), *targets)
        else:
            merged_predictions = tz.merge_with(
                lambda x: list(itertools.chain.from_iterable(x)),
                *predictions)
            merged_targets = tz.merge_with(
                lambda x: list(itertools.chain.from_iterable(x)), *targets)

        # Compute the metrics
        # TODO(karan): generalize this code to support metric computation for
        # any task
        if not is_classification:
            # Everything below needs `num_classes` and tensor predictions.
            # Fail with a clear error instead of an unbound local variable.
            raise NotImplementedError(
                "Metric computation is only supported for classification "
                "tasks")

        # Assumes classification, so the output_columns contains a single key
        # for the label
        assert len(output_columns) == 1, (
            "Classification expects exactly one output column.")

        output_schema = self.task.output_schema
        label_key = list(output_schema.keys())[0]
        num_classes = output_schema.features[label_key].num_classes

        labels = merged_targets[list(merged_targets.keys())[0]]

        pred = merged_predictions["pred"].to(self.device)
        target = labels.to(self.device)

        return {
            metric: compute_metric(metric, pred, target, num_classes)
            for metric in metrics
        }
    finally:
        # Reset the data format, also when prediction or metric computation
        # fails, so the dataset is never left restricted to a column subset.
        dataset.reset_format()