def transform(self, data: Dict[str, Any], mode: str, epoch: int = 1) -> Dict[str, Any]:
    """Run a forward step through the Network on an element of data.

    When the active TensorFlow distribution strategy is a MirroredStrategy, the
    element is first (optionally) padded and then distributed with that strategy
    before being handed to the parent class's `transform`.

    Args:
        data: The element of data to use as input.
        mode: The mode in which to run the transform. One of 'train', 'eval', 'test', or 'infer'.
        epoch: The epoch in which to run the transform.

    Returns:
        The dictionary returned by the parent class's `transform`. If padding was
        applied, it is passed through `_subsample_data` with the pre-padding batch size.
    """
    dist_strategy = tf.distribute.get_strategy()
    # Stays None unless the batch had to be padded; then holds the pre-padding size.
    original_batch = None
    if isinstance(dist_strategy, tf.distribute.MirroredStrategy):
        batch_size = get_batch_size(data)
        num_devices = dist_strategy.num_replicas_in_sync
        if batch_size < num_devices:
            # Pad up to one sample per replica (presumably so that no replica
            # receives an empty shard — confirm against _fill_batch).
            data = self._fill_batch(data, num_devices - batch_size)
            original_batch = batch_size
        single_element = tf.data.Dataset.from_tensors(data)
        distributed = dist_strategy.experimental_distribute_dataset(single_element)
        data = next(iter(distributed))
    results = super().transform(data, mode, epoch)
    if original_batch is None:
        return results
    # Remove the padding rows that were added above.
    return self._subsample_data(results, original_batch)
def transform(self, data: Dict[str, Any], mode: str, epoch: int = 1) -> Dict[str, Any]:
    """Run a forward step through the Network on an element of data.

    Args:
        data: The element of data to use as input.
        mode: The mode in which to run the transform. One of 'train', 'eval', 'test', or 'infer'.
        epoch: The epoch in which to run the transform.

    Returns:
        The first value returned by `run_step` (the batch dictionary), updated in
        place with the prediction entries. Under a MirroredStrategy, the predictions
        are first passed through `_subsample_data`.
    """
    self.load_epoch(mode, epoch, warmup=False)
    try:
        data = to_tensor(data, target_type="tf")
        data, prediction = self.run_step(data)
    finally:
        # Pair every successful load_epoch() with an unload_epoch(), even when
        # to_tensor()/run_step() raises; otherwise the cleanup would be skipped.
        self.unload_epoch()
    # handle tensorflow multi-gpu inferencing issue, it will replicate data on each device
    if isinstance(tf.distribute.get_strategy(), tf.distribute.MirroredStrategy):
        prediction = self._subsample_data(prediction, get_batch_size(data))
    data.update(prediction)
    return data