def _configure_loader(self, loader: Union[DataLoader, tf.data.Dataset]) -> Union[DataLoader, tf.data.Dataset]:
    """A method to configure a given dataloader for use with this Estimator's Network.

    This method will ensure that the `loader` returns the correct data type (tf.Tensor or torch.Tensor) depending
    on the requirements of the Network. It also handles issues with multi-GPU data sharding.

    Args:
        loader: A data loader to be modified.

    Returns:
        The potentially modified dataloader to be used for training.
    """
    new_loader = loader
    if isinstance(new_loader, DataLoader) and isinstance(self.network, TFNetwork):
        add_batch = bool(new_loader.batch_size)
        if hasattr(loader, 'fe_postprocess_fn') and loader.fe_postprocess_fn is not None:
            # The user is manually batching data and running ops on data batches. There is no reliable way to
            # shortcut this since ops might require a specific batch composition.
            data_instance = next(iter(loader))
            add_batch = False
        else:
            # No batch-based ops, so we can use the OpDataset to more quickly get our data summary.
            data_instance = loader.dataset[0]
            if isinstance(data_instance, list):
                # This is a batched dataset
                data_instance = data_instance[0]
                add_batch = True
            if isinstance(data_instance, FilteredData):
                # We got unlucky and drew filtered data as the zeroth element. Fall back to a slower but more
                # robust analysis of the batch.
                data_instance = next(iter(loader))
                add_batch = False
        data_instance = to_tensor(data_instance, target_type="tf")
        data_type = to_type(data_instance)
        data_shape = to_shape(data_instance, add_batch=add_batch, exact_shape=False)
        new_loader = tf.data.Dataset.from_generator(lambda: loader, data_type, output_shapes=data_shape)
        new_loader = new_loader.prefetch(1)
    if isinstance(new_loader, tf.data.Dataset):
        if self.system.train_steps_per_epoch and self.system.mode == "train":
            new_loader = new_loader.take(self.system.train_steps_per_epoch)
        if self.system.eval_steps_per_epoch and self.system.mode == "eval":
            new_loader = new_loader.take(self.system.eval_steps_per_epoch)
        if isinstance(tf.distribute.get_strategy(), tf.distribute.MirroredStrategy) and isinstance(
                self.network, TFNetwork) and not isinstance(new_loader, DistributedDataset):
            # The default auto-shard policy is FILE; switch it to DATA to avoid a warning.
            options = tf.data.Options()
            options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
            new_loader = new_loader.with_options(options)
            new_loader = tf.distribute.get_strategy().experimental_distribute_dataset(new_loader)
    return new_loader
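# A minimal, self-contained sketch (not part of this class) of the wrapping pattern `_configure_loader` relies
# on: a torch DataLoader is re-exposed to TensorFlow via `tf.data.Dataset.from_generator`, with per-key dtypes
# and batch-prefixed shapes inferred from a sample element. `ToyDataset`, the key names, and the shapes below
# are hypothetical stand-ins, not values from the library.
def _sketch_wrap_torch_loader_for_tf() -> None:
    import numpy as np
    import tensorflow as tf
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        def __len__(self) -> int:
            return 8

        def __getitem__(self, idx: int):
            return {"x": np.ones((28, 28), dtype="float32"), "y": np.int64(idx % 2)}

    loader = DataLoader(ToyDataset(), batch_size=4)
    data_type = {"x": tf.float32, "y": tf.int64}
    # The leading None stands in for the batch dimension (the add_batch=True / exact_shape=False case above)
    data_shape = {"x": tf.TensorShape([None, 28, 28]), "y": tf.TensorShape([None])}
    tf_data = tf.data.Dataset.from_generator(lambda: loader, data_type, output_shapes=data_shape).prefetch(1)
    for batch in tf_data.take(1):
        print(batch["x"].shape, batch["y"].dtype)  # -> (4, 28, 28) <dtype: 'int64'>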
def __init__(self, model_input: Any, model: Model) -> None:
    # Record the shape and dtype of the sample model input, along with a framework-matched factory for
    # fabricating placeholder tensors later (tf.ones for Keras models, torch.ones otherwise).
    self.shape = to_shape(model_input)
    self.dtype = to_type(model_input)
    self.tensor_func = tf.ones if isinstance(model, tf.keras.Model) else torch.ones
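# Hypothetical usage sketch (not from the source): a spec initialized as above can fabricate an all-ones
# placeholder matching the recorded model input, e.g. to exercise a model without real data. This assumes
# `to_shape`/`to_type` returned a flat shape tuple and a single dtype (they may return nested structures
# for multi-input models).
def _sketch_mock_input(spec: Any) -> Any:
    # `tensor_func` is tf.ones or torch.ones; both accept a shape sequence plus a dtype keyword.
    return spec.tensor_func(spec.shape, dtype=spec.dtype)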