Example no. 1
0
    def transform(self,
                  data: Dict[str, Any],
                  mode: str,
                  epoch: int = 1) -> Dict[str, Any]:
        """Run a forward step through the Network on an element of data.

        Args:
            data: The data element to use as input.
            mode: The mode in which to run the transform. One of 'train', 'eval', 'test', or 'infer'.
            epoch: The epoch in which to run the transform.

        Returns:
            The batch data merged with the network's predictions.
        """
        # Under MirroredStrategy the batch has to be spread over the replicas
        # before invoking the parent transform.
        padded = False
        strategy = tf.distribute.get_strategy()
        if isinstance(strategy, tf.distribute.MirroredStrategy):
            batch_size = get_batch_size(data)
            num_devices = strategy.num_replicas_in_sync
            if batch_size < num_devices:
                # Pad so that every replica receives at least one sample.
                data = self._fill_batch(data, num_devices - batch_size)
                padded = True
            dataset = tf.data.Dataset.from_tensors(data)
            data = next(iter(strategy.experimental_distribute_dataset(dataset)))
        results = super().transform(data, mode, epoch)
        if padded:
            # Strip the padding samples that were added above.
            results = self._subsample_data(results, batch_size)
        return results
Example no. 2
0
    def transform(self, data: Dict[str, Any], mode: str, epoch: int = 1) -> Dict[str, Any]:
        """Run a forward step through the Network on an element of data.

        Args:
            data: The data element to use as input.
            mode: The mode in which to run the transform. One of 'train', 'eval', 'test', or 'infer'.
            epoch: The epoch in which to run the transform.

        Returns:
            The input data dictionary updated with the prediction outputs.
        """
        self.load_epoch(mode, epoch, warmup=False)
        try:
            data = to_tensor(data, target_type="tf")
            data, prediction = self.run_step(data)
        finally:
            # Pair load_epoch with unload_epoch even when to_tensor or
            # run_step raises, so epoch state never leaks.
            self.unload_epoch()
        # handle tensorflow multi-gpu inferencing issue, it will replicate data on each device
        if isinstance(tf.distribute.get_strategy(), tf.distribute.MirroredStrategy):
            prediction = self._subsample_data(prediction, get_batch_size(data))
        data.update(prediction)
        return data