Example #1
0
    def update_batch_idx(self, new_batch_idx: int) -> None:
        """Record a new batch index and start/stop metric collection accordingly.

        Does nothing when the profiler is disabled. Otherwise:
          * verifies (via ``check``) that the index never goes backwards,
          * stores it and forwards it to the system-metrics collector thread,
          * begins collection the first time the index reaches ``start_on_batch``,
          * ends an active collection once the index passes ``end_after_batch``
            (when an end batch is configured).
        """
        if self.is_enabled:
            check.check_gt_eq(
                new_batch_idx,
                self.current_batch_idx,
                "Batch index should never decrease over time",
            )
            self.current_batch_idx = new_batch_idx
            self.sys_metric_collector_thread.update_batch_idx(self.current_batch_idx)

            # Collection starts only once: the first time the start batch is reached.
            if not self.has_started:
                if self.current_batch_idx >= self.start_on_batch:
                    self._begin_collection()

            # An open-ended window (end_after_batch is None) never stops on batch count.
            if self.is_active and self.end_after_batch is not None:
                if self.current_batch_idx > self.end_after_batch:
                    self._end_collection()
    def update_batch_idx(self, new_batch_idx: int) -> None:
        """Record a new batch index and drive the profiler's collection window.

        Returns immediately when the profiler is disabled. Otherwise the index is
        checked to be non-decreasing and stored, then:
          * if system metrics are enabled, the collector thread is told the new index;
          * if timings are enabled, a ``FinalizeBatchMessage`` is put on the metrics
            batcher queue;
          * collection begins the first time the index reaches ``begin_on_batch``;
          * once an active collection passes ``end_after_batch`` (if configured), the
            collection is ended and the shutdown timer is signalled.
        """
        if self.is_enabled:
            check.check_gt_eq(
                new_batch_idx,
                self.current_batch_idx,
                "Batch index should never decrease over time",
            )
            self.current_batch_idx = new_batch_idx

            if self.sysmetrics_is_enabled:
                self.sys_metric_collector_thread.update_batch_idx(self.current_batch_idx)

            if self.timings_is_enabled:
                self.metrics_batcher_queue.put(FinalizeBatchMessage())

            # Collection starts only once: the first time the begin batch is reached.
            if not self.has_started:
                if self.current_batch_idx >= self.begin_on_batch:
                    self._begin_collection()

            # Past the configured end batch: stop collecting and request shutdown.
            if self.is_active and self.end_after_batch is not None:
                if self.current_batch_idx > self.end_after_batch:
                    self._end_collection()
                    self.shutdown_timer.send_shutdown_signal()
Example #3
0
    def __init__(
        self,
        x: ArrayLike,
        y: ArrayLike,
        batch_size: int,
        sample_weights: Optional[np.ndarray] = None,
        drop_leftovers: bool = False,
    ):
        """
        Validate numpy input/target data and store it with its batching parameters.

        If converting numpy array data to Sequence to optimize performance, consider
        using ArrayLikeAdapter.

        Args:
            x: Input data. It must be one of:
                1) A Numpy array, or a list of Numpy arrays (in case the model has
                multiple inputs).
                2) A dict mapping input names to the corresponding Numpy array, if the
                model has named inputs.
                Anything else raises InvalidDataTypeException.

            y: Target data. Accepts the same forms as x: a Numpy array, or a
                list/dict of Numpy arrays.

            batch_size: Number of samples per batch. Must be at least 1 and must not
                exceed the number of samples in x.

            sample_weights: Optional Numpy array of per-sample weights. Its length must
                equal the number of samples in x.

            drop_leftovers: If True, drop the data that cannot complete the last batch.

        Raises:
            det.errors.InvalidDataTypeException: If x or y is not a Numpy array or a
                list/dict of Numpy arrays.
            The error raised by the ``check`` helpers (presumably
            ``check.CheckFailedError`` -- confirm) if the lengths of x, y and
            sample_weights disagree, or if batch_size is out of range.
        """

        # x and y share the same accepted forms; validate x first, then y, so the
        # reported type is that of the first offending argument.
        for data in (x, y):
            if not (
                isinstance(data, np.ndarray)
                or _is_list_of_numpy_array(data)
                or _is_dict_of_numpy_array(data)
            ):
                raise det.errors.InvalidDataTypeException(
                    type(data),
                    "Data which is not tf.data.Datasets or tf.keras.utils.Sequence objects must be a "
                    "numpy array or a list/dict of numpy arrays. See the instructions below for "
                    f"details:\n{keras.TFKerasTrial.build_training_data_loader.__doc__}",
                )

        self._x_length = _length_of_multi_arraylike(x)
        self._y_length = _length_of_multi_arraylike(y)

        check.eq(self._x_length, self._y_length, "Length of x and y do not match.")
        # A batch size of 0 or less slips past the upper-bound check below but can
        # never hold a sample, so reject it explicitly.
        check.check_gt_eq(batch_size, 1, "Batch size must be at least 1.")
        check.check_gt_eq(self._x_length, batch_size,
                          "Batch size is too large for the input data.")
        if sample_weights is not None:
            check.eq(
                self._x_length,
                len(sample_weights),
                "Lengths of input data and sample weights do not match.",
            )

        self.x = x
        self.y = y
        self.sample_weight = sample_weights

        self.batch_size = batch_size
        self.drop_leftovers = drop_leftovers