def test_step_jit(
    self,
    x: tp.Any = (),
    y: tp.Any = None,
    sample_weight: tp.Optional[np.ndarray] = None,
    class_weight: tp.Optional[np.ndarray] = None,
):
    """Call ``self.test_fn_jit`` on one batch with ``training=False``.

    The call happens inside ``module.training_context(training=False)`` and
    ``module.hooks_context()``.

    Args:
        x: Forwarded as the first positional argument.
        y: Forwarded as the second positional argument.
        sample_weight: Forwarded as the third positional argument.
        class_weight: Forwarded as the fourth positional argument.

    Returns:
        Whatever ``self.test_fn_jit`` returns, unchanged.
    """
    with module.training_context(training=False):
        with module.hooks_context():
            # Arguments are passed positionally, in signature order.
            outputs = self.test_fn_jit(x, y, sample_weight, class_weight)
            return outputs
def train_step_jit(
    self,
    x: tp.Any = (),
    y: tp.Any = None,
    sample_weight: tp.Optional[np.ndarray] = None,
    class_weight: tp.Optional[np.ndarray] = None,
):
    """Call ``self.train_fn_jit`` on one batch with ``training=True``.

    The call happens inside ``module.training_context(training=True)`` and
    ``module.hooks_context()``.

    Args:
        x: Forwarded as the first positional argument.
        y: Forwarded as the second positional argument.
        sample_weight: Forwarded as the third positional argument.
        class_weight: Forwarded as the fourth positional argument.

    Returns:
        Whatever ``self.train_fn_jit`` returns, unchanged.
    """
    with module.training_context(training=True):
        with module.hooks_context():
            # Arguments are passed positionally, in signature order.
            return self.train_fn_jit(x, y, sample_weight, class_weight)
def train_step(
    self,
    x: tp.Any = (),
    y: tp.Any = None,
    sample_weight: tp.Optional[np.ndarray] = None,
    class_weight: tp.Optional[np.ndarray] = None,
) -> tp.Dict[str, tp.Any]:
    """Call ``self.train_fn`` on one batch with ``training=True``.

    The call happens inside ``module.training_context(training=True)`` and
    ``module.hooks_context()``. Every argument is forwarded by keyword under
    its own name.

    Args:
        x: Forwarded as ``x``.
        y: Forwarded as ``y``.
        sample_weight: Forwarded as ``sample_weight``.
        class_weight: Forwarded as ``class_weight``.

    Returns:
        Whatever ``self.train_fn`` returns, unchanged.
    """
    with module.training_context(training=True):
        with module.hooks_context():
            logs = self.train_fn(
                x=x,
                y=y,
                sample_weight=sample_weight,
                class_weight=class_weight,
            )
            return logs
def maybe_initialize(
    self,
    mode: Mode,
    x: tp.Union[jnp.ndarray, tp.Mapping[str, tp.Any], tp.Tuple] = (),
    y: tp.Union[jnp.ndarray, tp.Mapping[str, tp.Any], tp.Tuple, None] = None,
    sample_weight: tp.Optional[jnp.ndarray] = None,
    class_weight: tp.Optional[jnp.ndarray] = None,
):
    """Initialize the module, metrics and optimizer that are not yet ready.

    All work runs inside ``module.init_context()``,
    ``module.training_context(training=True)`` and ``module.hooks_context()``.
    Stages run in a fixed order, and ``mode`` decides where to stop:

    1. If ``self.module.initialized`` is false, call ``self.predict_fn(x=x)``
       and set the flag. Stop here when ``mode == Mode.predict``.
    2. If ``self.metrics`` exists and is not initialized, call
       ``self.test_fn`` with the batch, set the flag, and store
       ``self.metrics.get_parameters(trainable=False)`` in
       ``self.initial_metrics_state``. Stop here when ``mode == Mode.test``.
    3. If ``self.optimizer`` exists and is not initialized, call
       ``self.train_fn`` with the batch and set the flag.

    Args:
        mode: Selects how many stages run (see above).
        x: Passed to ``predict_fn``, ``test_fn`` and ``train_fn`` as ``x``.
        y: Passed to ``test_fn`` and ``train_fn`` as ``y``.
        sample_weight: Passed to ``test_fn`` and ``train_fn``.
        class_weight: Passed to ``test_fn`` and ``train_fn``.

    Returns:
        None.
    """
    # Keyword arguments shared by the metrics and optimizer stages.
    batch_kwargs = {
        "x": x,
        "y": y,
        "sample_weight": sample_weight,
        "class_weight": class_weight,
    }

    with module.init_context(), module.training_context(training=True):
        with module.hooks_context():
            assert self.module is not None

            # Stage 1: the module itself.
            if not self.module.initialized:
                self.predict_fn(x=x)
                self.module.initialized = True

            if mode == Mode.predict:
                return

            # Stage 2: metrics, plus a snapshot of their non-trainable
            # parameters.
            if not (self.metrics is None or self.metrics.initialized):
                self.test_fn(**batch_kwargs)
                self.metrics.initialized = True
                self.initial_metrics_state = self.metrics.get_parameters(
                    trainable=False
                )

            if mode == Mode.test:
                return

            # Stage 3: the optimizer.
            if not (self.optimizer is None or self.optimizer.initialized):
                self.train_fn(**batch_kwargs)
                self.optimizer.initialized = True
def test_step(
    self,
    x: tp.Any = (),
    y: tp.Any = None,
    sample_weight: tp.Optional[np.ndarray] = None,
    class_weight: tp.Optional[np.ndarray] = None,
    get_gradients: bool = False,
) -> tp.Tuple[np.ndarray, tp.Dict, tp.Optional[tp.Dict]]:
    """Call ``self.test_fn`` on one batch with ``training=False``.

    The call happens inside ``module.training_context(training=False)`` and
    ``module.hooks_context()``. Every argument is forwarded by keyword under
    its own name.

    Args:
        x: Forwarded as ``x``.
        y: Forwarded as ``y``.
        sample_weight: Forwarded as ``sample_weight``.
        class_weight: Forwarded as ``class_weight``.
        get_gradients: Forwarded as ``get_gradients``.

    Returns:
        Whatever ``self.test_fn`` returns, unchanged.
    """
    with module.training_context(training=False):
        with module.hooks_context():
            result = self.test_fn(
                x=x,
                y=y,
                sample_weight=sample_weight,
                class_weight=class_weight,
                get_gradients=get_gradients,
            )
            return result
def predict_step_jit(self, x: tp.Any = ()):
    """Call ``self.predict_fn_jit`` on ``x`` with ``training=False``.

    The call happens inside ``module.training_context(training=False)``.
    No hooks context is opened here.

    Args:
        x: Forwarded as the single positional argument.

    Returns:
        Whatever ``self.predict_fn_jit`` returns, unchanged.
    """
    with module.training_context(training=False):
        predictions = self.predict_fn_jit(x)
        return predictions