Example #1
0
    def _process_scan(self, args, kwargs):
        """Execute a scan, routing single-``parametrized``-equation bodies specially.

        The body jaxpr is read from ``kwargs['jaxpr']`` (through its ``.jaxpr.eqns``).
        If that body consists of exactly one equation whose primitive is an
        instance of ``parametrized``, the scan is run via ``_custom_cell_scan_impl``.
        That call uses a cell function that forwards to ``self.process_parametrized``
        with the equation's primitive and params. Any other body falls back to the
        plain ``_scan_impl``.

        Args:
            args: positional operands, forwarded unchanged to the scan implementation.
            kwargs: keyword parameters of the scan; must contain ``'jaxpr'``
                (presumably a closed jaxpr for the scan body — confirm with callers).

        Returns:
            Whatever the selected scan implementation returns.
        """
        body_eqns = kwargs['jaxpr'].jaxpr.eqns

        # Guard clause: anything other than a lone parametrized equation takes
        # the generic scan path.
        if len(body_eqns) != 1 or not isinstance(body_eqns[0].primitive, parametrized):
            return _scan_impl(*args, **kwargs)

        (only_eqn,) = body_eqns
        # Bind the primitive and its params so the custom scan only has to
        # supply the per-step operands.
        cell_fn = partial(self.process_parametrized, only_eqn.primitive, **only_eqn.params)
        return _custom_cell_scan_impl(cell_fn, *args, **kwargs)
Example #2
0
 def func1(*jax_args):
     """Forward the positional operands to ``lax_control_flow._scan_impl``.

     Keyword arguments are taken from ``kwargs`` in the enclosing scope, which
     is not visible here. Presumably these are the scan's parameters; confirm
     against the enclosing function.
     """
     scan_impl = lax_control_flow._scan_impl
     return scan_impl(*jax_args, **kwargs)