def jax_fn(flat_states, inputs):
    states = jax.tree_unflatten(states_def, flat_states)
    y_pred, _ = utils.inject_dependencies(self.pred_step)(
        x=inputs,
        states=states,
        initializing=False,
        training=False,
    )
    return y_pred
def call_train_step(
    self,
    x: tp.Any,
    y_true: tp.Any,
    mode: types.Mode,
    sample_weight: tp.Optional[np.ndarray],
    class_weight: tp.Optional[np.ndarray],
    states: types.States,
    initializing: bool,
) -> TrainStep:
    return utils.inject_dependencies(self.train_step)(
        x=x,
        y_true=y_true,
        mode=mode,
        net_params=states.net_params,
        net_states=states.net_states,
        metrics_states=states.metrics_states,
        optimizer_states=states.optimizer_states,
        sample_weight=sample_weight,
        class_weight=class_weight,
        rng=states.rng,
        states=states,
        initializing=initializing,
        training=(mode == types.Mode.train),
    )
def apply_recursive(context: tp.Tuple[str, ...], metrics, **kwargs):
    if isinstance(metrics, tp.Callable):
        name = (
            metrics.name
            if isinstance(metrics, module.Module)
            else utils.lower_snake_case(metrics.__name__)
        )
        context += (name,)
        value = utils.inject_dependencies(metrics)(**kwargs)

        if isinstance(value, tp.Dict):
            for name, value in value.items():
                yield context + (name,), value
        else:
            yield context, value

    elif isinstance(metrics, (tp.Tuple, tp.List)):
        for metric in metrics:
            yield from apply_recursive(context, metric, **kwargs)

    elif isinstance(metrics, tp.Dict):
        for name, metric in metrics.items():
            yield from apply_recursive(context + (name,), metric, **kwargs)

    else:
        raise TypeError(f"Invalid type {type(metrics)}")
def apply_recursive(self, context: tp.Tuple[str, ...], losses, **kwargs):
    if isinstance(losses, tp.Callable):
        name = (
            losses.name
            if isinstance(losses, Loss)
            else utils.lower_snake_case(losses.__name__)
        )
        context += (name,)
        val = utils.inject_dependencies(losses)(**kwargs)

        if isinstance(val, tp.Dict):
            for name, val in val.items():
                yield context + (name,), val
        else:
            yield context, val

    elif isinstance(losses, (tp.Tuple, tp.List)):
        for loss in losses:
            yield from self.apply_recursive(context, loss, **kwargs)

    elif isinstance(losses, tp.Dict):
        for name, loss in losses.items():
            yield from self.apply_recursive(context + (name,), loss, **kwargs)

    else:
        raise TypeError(f"Invalid type {type(losses)}")
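# A minimal, self-contained sketch of the traversal pattern shared by both
# apply_recursive variants above (illustrative only; the real code also
# filters kwargs per-callable via utils.inject_dependencies and derives names
# from Loss/Module instances): callables are leaves, lists/tuples fan out
# under the same context, and dict keys extend the name path.
import typing as tp


def flatten_example(context: tp.Tuple[str, ...], losses, **kwargs):
    if callable(losses):
        yield context + (losses.__name__,), losses(**kwargs)
    elif isinstance(losses, (tuple, list)):
        for loss in losses:
            yield from flatten_example(context, loss, **kwargs)
    elif isinstance(losses, dict):
        for name, loss in losses.items():
            yield from flatten_example(context + (name,), loss, **kwargs)
    else:
        raise TypeError(f"Invalid type {type(losses)}")


def mse(y_true, y_pred):
    return sum((a - b) ** 2 for a, b in zip(y_true, y_pred)) / len(y_true)


pairs = dict(
    flatten_example((), {"supervised": [mse]}, y_true=[1.0, 2.0], y_pred=[1.0, 1.0])
)
assert pairs == {("supervised", "mse"): 0.5}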
def call_pred_step(
    self,
    x: tp.Any,
    mode: types.Mode,
    states: types.States,
    initializing: bool,
) -> PredStep:
    # losses/metrics are only collected in train/test mode; summaries only
    # in summary mode
    get_losses_and_metrics = mode in (types.Mode.test, types.Mode.train)
    get_summaries = mode == types.Mode.summary

    with hooks.context(
        losses=get_losses_and_metrics,
        metrics=get_losses_and_metrics,
        summaries=get_summaries,
    ):
        return utils.inject_dependencies(self.pred_step)(
            x=x,
            mode=mode,
            net_params=states.net_params,
            net_states=states.net_states,
            rng=states.rng,
            states=states,
            initializing=initializing,
            training=(mode == types.Mode.train),
        )
def __init__(
    self,
    reduction: tp.Optional[Reduction] = None,
    name: tp.Optional[str] = None,
    weight: tp.Optional[float] = None,
    on: tp.Optional[types.IndexLike] = None,
):
    """
    Initializes `Loss` class.

    Arguments:
        reduction: (Optional) Type of `elegy.losses.Reduction` to apply to
            the loss. Defaults to `SUM_OVER_BATCH_SIZE`. When used with
            `tf.distribute.Strategy`, outside of built-in training loops
            such as `elegy` `compile` and `fit`, using `SUM_OVER_BATCH_SIZE`
            will raise an error.
        name: Optional name for the loss instance. Defaults to the
            snake_case version of the class name.
        weight: Optional weight contribution for the total loss.
            Defaults to `1`.
        on: A string or integer, or iterable of strings or integers, that
            indicates how to index/filter the `y_true` and `y_pred`
            arguments before passing them to `call`.
    """
    self.name = (
        name
        if name is not None
        else re.sub(r"(?<!^)(?=[A-Z])", "_", self.__class__.__name__).lower()
    )
    self.weight = weight if weight is not None else 1.0
    self._reduction = (
        reduction if reduction is not None else Reduction.SUM_OVER_BATCH_SIZE
    )
    self._labels_filter = (on,) if isinstance(on, (str, int)) else on
    self.call = utils.inject_dependencies(self.call)
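# Quick standalone check of the name-derivation regex used above: it inserts
# an underscore before every interior capital letter, then lowercases, so a
# subclass named e.g. MeanSquaredError gets the default name
# "mean_squared_error".
import re


def snake(class_name: str) -> str:
    return re.sub(r"(?<!^)(?=[A-Z])", "_", class_name).lower()


assert snake("MeanSquaredError") == "mean_squared_error"
assert snake("CategoricalCrossentropy") == "categorical_crossentropy"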
def test_positional_error_remaining(self):
    def f(a, b, c):
        return a + b + c

    g = utils.inject_dependencies(f)

    with pytest.raises(TypeError):
        g("a", "b", "c", "d")
def test_keyword_error_missing(self):
    def f(a, b, c):
        return a + b + c

    g = utils.inject_dependencies(f)

    with pytest.raises(TypeError):
        g(b="b", c="c")
def test_positional_error_missing(self):
    def f(a, b, c):
        return a + b + c

    g = utils.inject_dependencies(f)

    with pytest.raises(TypeError):
        g("a", "b")
def test_keyword(self):
    def f(a, b, c):
        return a + b + c

    g = utils.inject_dependencies(f)

    y = g(b="b", c="c", a="a")
    assert y == "abc"
def test_positional(self):
    def f(a, b, c):
        return a + b + c

    g = utils.inject_dependencies(f)

    y = g("a", "b", "c")
    assert y == "abc"
def call_summary_step(
    self,
    x: tp.Any,
    states: types.States,
) -> tp.List[types.SummaryTableEntry]:
    return utils.inject_dependencies(self.summary_step)(
        x=x,
        states=states,
    )
def _lambda(*args, **kwargs):
    y_pred, parameters, collections = utils.inject_dependencies(
        self.module.init(rng=rng)
    )(
        *args,
        **kwargs,
    )
    return types.OutputStates(y_pred, parameters, collections)
def test_keyword_extras_ok(self):
    def f(a, b, c):
        return a + b + c

    g = utils.inject_dependencies(f)

    y = g(b="b", c="c", a="a", d="d")
    assert y == "abc"
def test_positional_extras_ok(self):
    def f(a, b, c):
        return a + b + c

    g = utils.inject_dependencies(f)

    y = g("a", "b", "c", d="d")
    assert y == "abc"
def test_mixed(self):
    def f(a, b, c):
        return a + b + c

    g = utils.inject_dependencies(f)

    y = g("a", c="c", b="b")
    assert y == "abc"
def test_mixed_ignore_duplicated_kwarg_in_arg(self):
    def f(a, b, c):
        return a + b + c

    g = utils.inject_dependencies(f)

    y = g("a", c="c", b="b", a="f")
    assert y == "abc"
def test_override_defaults(self):
    def f(a, b, c="x"):
        return a + b + c

    g = utils.inject_dependencies(f)

    y = g("a", c="c", b="b")
    assert y == "abc"
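# Taken together, the tests above pin down inject_dependencies' contract:
# positional arguments bind left to right, unknown keyword arguments are
# silently dropped, keyword arguments duplicating an already-bound positional
# are ignored, and too many positionals or a missing required argument still
# raise TypeError. A minimal sketch consistent with that contract (an
# illustration only, not the library's implementation; the real one also
# supports filtering against an alternate signature via signature_f, as seen
# in the wrappers below):
import inspect


def inject_dependencies_sketch(f):
    params = list(inspect.signature(f).parameters)

    def wrapper(*args, **kwargs):
        if len(args) > len(params):
            raise TypeError(f"too many positional arguments for {f.__name__}")
        bound = dict(zip(params, args))
        for name, value in kwargs.items():
            # keep only kwargs the function declares and hasn't bound yet
            if name in params and name not in bound:
                bound[name] = value
        return f(**bound)  # a missing required argument raises TypeError here

    return wrapper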
def _lambda(*args, **kwargs):
    y_pred, net_params, net_states = utils.inject_dependencies(
        self.module.apply(params, states, training=training, rng=rng),
    )(
        *args,
        **kwargs,
    )
    return types.OutputStates(y_pred, net_params, net_states)
def __init__(
    self,
    name: tp.Optional[str] = None,
    dtype: tp.Optional[jnp.dtype] = None,
    on: tp.Optional[types.IndexLike] = None,
):
    super().__init__(name=name)
    self._dtype = dtype if dtype is not None else jnp.float32
    self._labels_filter = (on,) if isinstance(on, (str, int)) else on
    self.call = utils.inject_dependencies(self.call)
def _lambda(*args, **kwargs):
    def apply_fn(*args, **kwargs):
        # rename kwargs with a double-underscore prefix before forwarding
        # them to module.apply
        kwargs = {f"__{name}": value for name, value in kwargs.items()}
        return self.module.apply(params, states, rng.next(), *args, **kwargs)

    # inject dependencies based on self.f's signature rather than apply_fn's
    y_pred, states_ = utils.inject_dependencies(
        apply_fn,
        signature_f=self.f,
    )(
        *args,
        **kwargs,
    )
    return types.OutputStates(y_pred, params, states_)
def lambda_(*args, **kwargs) -> types.OutputStates:
    output = utils.inject_dependencies(self.f)(*args, **kwargs)

    if isinstance(output, types.OutputStates):
        return output
    else:
        return types.OutputStates(
            preds=output,
            params=types.UNINITIALIZED,
            states=types.UNINITIALIZED,
        )
def call_pred_step(
    self,
    x: tp.Any,
    states: types.States,
    initializing: bool,
    training: bool,
) -> PredStep:
    return utils.inject_dependencies(self.pred_step)(
        x=x,
        states=states,
        initializing=initializing,
        training=training,
    )
def _lambda(*args, **kwargs):
    y_pred, collections = utils.inject_dependencies(
        self.module.init(rng=rng)
    )(
        *args,
        **kwargs,
    )
    assert isinstance(collections, dict)

    # split the collections dict into trainable parameters and the rest
    net_params = collections.pop("parameters", {})
    net_states = collections

    return types.OutputStates(y_pred, net_params, net_states)
def calculate_losses(self, *args, **kwargs) -> types.Logs:
    logs: types.Logs = {}

    for name, loss_fn in self.losses.items():
        losses = utils.inject_dependencies(loss_fn)(*args, **kwargs)

        names = set()
        for inner_name, loss in utils.flatten_names(losses):
            inner_name = f"{name}/{inner_name}" if inner_name else name
            inner_name = utils.get_unique_name(names, inner_name)
            logs[inner_name] = loss

    return logs
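# The loop above leans on two helpers; a standalone sketch of their assumed
# contracts (get_unique_name_sketch below is hypothetical, inferred from
# usage): flatten_names yields (inner_name, value) pairs for nested loss
# outputs, producing "group/inner" log keys, while get_unique_name
# de-duplicates repeated keys, e.g. by appending a counter.
def get_unique_name_sketch(names: set, name: str) -> str:
    if name not in names:
        names.add(name)
        return name
    i = 1
    while f"{name}_{i}" in names:
        i += 1
    names.add(f"{name}_{i}")
    return f"{name}_{i}"


names = set()
assert get_unique_name_sketch(names, "loss") == "loss"
assert get_unique_name_sketch(names, "loss") == "loss_1"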
def _lambda(*args, **kwargs) -> types.OutputStates:
    preds = utils.inject_dependencies(self.f)(*args, **kwargs)

    if isinstance(preds, types.OutputStates):
        return preds

    # initialize the running-average accumulator: a sample count and a
    # zeroed pytree with the same structure as the predictions
    n = 0
    total = jax.tree_map(lambda x: jnp.zeros_like(x), preds)

    return types.OutputStates(
        preds=preds,
        params=None,
        states=(n, total),
    )
def call_init_step(
    self,
    x: tp.Any,
    y_true: tp.Any,
    sample_weight: tp.Optional[np.ndarray],
    class_weight: tp.Optional[np.ndarray],
    states: types.States,
) -> types.States:
    return utils.inject_dependencies(self.init_step)(
        x=x,
        y_true=y_true,
        sample_weight=sample_weight,
        class_weight=class_weight,
        states=states,
    )
def _lambda(*args, **kwargs):
    collections = states.copy() if states is not None else {}

    if params is not None:
        collections["parameters"] = params

    y_pred, collections = utils.inject_dependencies(
        self.module.apply(collections, training=training, rng=rng),
    )(
        *args,
        **kwargs,
    )
    assert isinstance(collections, dict)

    net_params = collections.pop("parameters", {})
    net_states = collections

    return types.OutputStates(y_pred, net_params, net_states)
def _lambda(*args, **kwargs):
    def init_fn(*args, **kwargs) -> types.OutputStates:
        kwargs = {f"__{name}": value for name, value in kwargs.items()}
        key = rng.next()
        y_pred, params, states = self.module.init(key, *args, **kwargs)
        return types.OutputStates(y_pred, params, states)

    y_pred, params, states = utils.inject_dependencies(
        init_fn,
        signature_f=self.f,
    )(
        *args,
        **kwargs,
    )
    return types.OutputStates(y_pred, params, states)
def _lambda(*args, **kwargs) -> types.OutputStates:
    preds = utils.inject_dependencies(self.f)(*args, **kwargs)

    if isinstance(preds, types.OutputStates):
        return preds

    # accumulate the running sum and return the running mean of all
    # predictions seen so far
    n, total = states
    n += 1
    total = jax.tree_multimap(lambda a, b: a + b, preds, total)
    preds = jax.tree_map(lambda total: total / n, total)

    return types.OutputStates(
        preds=preds,
        params=None,
        states=(n, total),
    )
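# Worked check of the accumulator above: states carries (count, running_sum),
# so the returned preds is the running mean of every prediction seen so far.
# Standalone sketch with plain floats standing in for pytrees:
def running_mean_step(states, pred):
    n, total = states
    n += 1
    total = total + pred
    return total / n, (n, total)


states = (0, 0.0)
mean, states = running_mean_step(states, 2.0)  # mean == 2.0
mean, states = running_mean_step(states, 4.0)  # mean == (2 + 4) / 2 == 3.0
assert mean == 3.0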