Пример #1
0
 def test_step_jit(
     self,
     x: tp.Any = (),
     y: tp.Any = None,
     sample_weight: tp.Optional[np.ndarray] = None,
     class_weight: tp.Optional[np.ndarray] = None,
 ):
     """Evaluate one batch through the jitted test function.

     Runs ``self.test_fn_jit`` with training disabled and inside the hooks
     context. The arguments are forwarded positionally, in signature order.

     Args:
         x: Model inputs.
         y: Targets, if any.
         sample_weight: Optional sample weights.
         class_weight: Optional class weights.

     Returns:
         Whatever ``self.test_fn_jit`` returns.
     """
     # Nested ``with`` blocks are equivalent to the single multi-item form:
     # the training context is entered first and exited last.
     with module.training_context(training=False):
         with module.hooks_context():
             return self.test_fn_jit(x, y, sample_weight, class_weight)
Пример #2
0
    def train_step_jit(
        self,
        x: tp.Any = (),
        y: tp.Any = None,
        sample_weight: tp.Optional[np.ndarray] = None,
        class_weight: tp.Optional[np.ndarray] = None,
    ):
        """Run one training step through the jitted train function.

        Calls ``self.train_fn_jit`` with training enabled and inside the hooks
        context. The arguments are forwarded positionally, in signature order.

        Args:
            x: Model inputs.
            y: Targets, if any.
            sample_weight: Optional sample weights.
            class_weight: Optional class weights.

        Returns:
            Whatever ``self.train_fn_jit`` returns.
        """
        with module.training_context(training=True):
            with module.hooks_context():
                step_outputs = self.train_fn_jit(x, y, sample_weight, class_weight)

        # The result is returned only after both contexts have been exited.
        return step_outputs
Пример #3
0
    def train_step(
        self,
        x: tp.Any = (),
        y: tp.Any = None,
        sample_weight: tp.Optional[np.ndarray] = None,
        class_weight: tp.Optional[np.ndarray] = None,
    ) -> tp.Dict[str, tp.Any]:
        """Run one training step through the non-jitted train function.

        Calls ``self.train_fn`` with keyword arguments, with training enabled
        and inside the hooks context.

        Args:
            x: Model inputs.
            y: Targets, if any.
            sample_weight: Optional sample weights.
            class_weight: Optional class weights.

        Returns:
            The value returned by ``self.train_fn``. The return annotation
            declares it as a ``str``-keyed dict.
        """
        with module.training_context(training=True):
            with module.hooks_context():
                return self.train_fn(
                    x=x,
                    y=y,
                    sample_weight=sample_weight,
                    class_weight=class_weight,
                )
Пример #4
0
    def maybe_initialize(
        self,
        mode: Mode,
        x: tp.Union[jnp.ndarray, tp.Mapping[str, tp.Any], tp.Tuple] = (),
        y: tp.Union[jnp.ndarray, tp.Mapping[str, tp.Any], tp.Tuple, None] = None,
        sample_weight: tp.Optional[jnp.ndarray] = None,
        class_weight: tp.Optional[jnp.ndarray] = None,
    ):
        """Lazily initialize the module, metrics and optimizer needed for ``mode``.

        Initialization happens by calling the corresponding step function once.
        All calls run inside ``module.init_context()``, with training enabled
        and the hooks context active. The stages run in this order:

        1. If ``self.module`` is not initialized, ``self.predict_fn`` is called
           and the module is marked initialized. Stops here for
           ``Mode.predict``.
        2. If ``self.metrics`` is not None and not initialized, ``self.test_fn``
           is called and the metrics are marked initialized. Their
           non-trainable parameters are then saved to
           ``self.initial_metrics_state``. Stops here for ``Mode.test``.
        3. If ``self.optimizer`` is not None and not initialized,
           ``self.train_fn`` is called and the optimizer is marked initialized.

        Args:
            mode: Which stages to run (see above).
            x: Model inputs.
            y: Targets, if any.
            sample_weight: Optional sample weights.
            class_weight: Optional class weights.

        Returns:
            None.
        """
        with module.init_context(), module.training_context(
            training=True
        ), module.hooks_context():
            assert self.module is not None

            # Stage 1: the model itself (inputs only).
            if not self.module.initialized:
                self.predict_fn(x=x)
                self.module.initialized = True

            if mode == Mode.predict:
                return

            # test_fn and train_fn receive identical keyword arguments.
            step_kwargs = {
                "x": x,
                "y": y,
                "sample_weight": sample_weight,
                "class_weight": class_weight,
            }

            # Stage 2: metrics. Snapshot their freshly created state.
            if self.metrics is not None and not self.metrics.initialized:
                self.test_fn(**step_kwargs)
                self.metrics.initialized = True
                self.initial_metrics_state = self.metrics.get_parameters(
                    trainable=False
                )

            if mode == Mode.test:
                return

            # Stage 3: optimizer.
            if self.optimizer is not None and not self.optimizer.initialized:
                self.train_fn(**step_kwargs)
                self.optimizer.initialized = True
Пример #5
0
    def test_step(
        self,
        x: tp.Any = (),
        y: tp.Any = None,
        sample_weight: tp.Optional[np.ndarray] = None,
        class_weight: tp.Optional[np.ndarray] = None,
        get_gradients: bool = False,
    ) -> tp.Tuple[np.ndarray, tp.Dict, tp.Optional[tp.Dict]]:
        """Evaluate one batch through the non-jitted test function.

        Calls ``self.test_fn`` with keyword arguments, with training disabled
        and inside the hooks context.

        Args:
            x: Model inputs.
            y: Targets, if any.
            sample_weight: Optional sample weights.
            class_weight: Optional class weights.
            get_gradients: Passed through to ``self.test_fn`` unchanged.

        Returns:
            The value returned by ``self.test_fn``. The return annotation
            declares it as ``(ndarray, dict, optional dict)``.
        """
        with module.training_context(training=False):
            with module.hooks_context():
                return self.test_fn(
                    x=x,
                    y=y,
                    sample_weight=sample_weight,
                    class_weight=class_weight,
                    get_gradients=get_gradients,
                )
Пример #6
0
 def predict_step_jit(self, x: tp.Any = ()):
     """Run the jitted prediction function on ``x`` with training disabled.

     Unlike the sibling step methods, this one enters only the training
     context and not ``module.hooks_context()``.

     Args:
         x: Model inputs, forwarded positionally to ``self.predict_fn_jit``.

     Returns:
         Whatever ``self.predict_fn_jit`` returns.
     """
     inference_context = module.training_context(training=False)
     with inference_context:
         return self.predict_fn_jit(x)