Exemple #1
0
    def _update_progress_bar(num_iter):
        """Drive a host-side tqdm bar from inside a traced JAX loop.

        Issues up to three conditional host callbacks for iteration ``num_iter``:
        ``_define_tqdm`` on iteration 0, ``_update_tqdm(print_rate)`` on every
        multiple of ``print_rate`` except ``num_samples - remainder``, and
        ``_update_tqdm(remainder)`` exactly at ``num_samples - remainder``.
        ``print_rate``, ``num_samples`` and ``remainder`` come from the enclosing
        scope.
        """
        def _tap_when(predicate, tap_fn, tap_arg):
            # Both branches yield ``num_iter`` so lax.cond sees matching outputs;
            # only the true branch performs the host callback.
            lax.cond(
                predicate,
                lambda _: host_callback.id_tap(tap_fn, tap_arg, result=num_iter),
                lambda _: num_iter,
                operand=None,
            )

        _tap_when(num_iter == 0, _define_tqdm, None)

        # At ``tail_start`` the bar is advanced by ``remainder`` instead of
        # ``print_rate``, so that index is excluded from the regular update.
        tail_start = num_samples - remainder
        _tap_when(
            (num_iter % print_rate == 0) & (num_iter != tail_start),
            _update_tqdm,
            print_rate,
        )
        _tap_when(num_iter == tail_start, _update_tqdm, remainder)
Exemple #2
0
    def _update_progress_bar(iter_num):
        """Drive a tqdm progress bar from a JAX loop via device-tagged host taps.

        Issues up to three conditional callbacks, each with ``tap_with_device=True``:
        ``_update_tqdm(0)`` when ``iter_num == 1``, ``_update_tqdm(print_rate)``
        whenever ``iter_num`` is a multiple of ``print_rate``, and
        ``_close_tqdm(remainder)`` when ``iter_num == num_samples``.
        ``print_rate``, ``num_samples`` and ``remainder`` come from the enclosing
        scope.
        """
        # NOTE(review): if iter_num reaches num_samples and num_samples is a
        # multiple of print_rate, that iteration fires both the print_rate update
        # and the close callback -- confirm this is intended.
        def _tap_if(predicate, tap_fn, tap_arg):
            # Both branches return ``iter_num``; only the true branch taps the host.
            lax.cond(
                predicate,
                lambda _: host_callback.id_tap(
                    tap_fn, tap_arg, result=iter_num, tap_with_device=True),
                lambda _: iter_num,
                operand=None,
            )

        _tap_if(iter_num == 1, _update_tqdm, 0)
        _tap_if(iter_num % print_rate == 0, _update_tqdm, print_rate)
        _tap_if(iter_num == num_samples, _close_tqdm, remainder)
    def _update_progress_bar(iter_num):
        """Update the tqdm progress bar of a JAX scan or loop.

        Calls ``_define_tqdm`` on iteration 0, ``_update_tqdm(print_rate)`` on
        multiples of ``print_rate`` other than ``num_samples - remainder``, and
        ``_update_tqdm(remainder)`` exactly at ``num_samples - remainder``.
        ``print_rate``, ``num_samples`` and ``remainder`` are closure variables.
        """
        tail_start = num_samples - remainder

        def _tap_if(predicate, tap_fn, tap_arg):
            # Matching outputs for lax.cond: both sides produce ``iter_num``.
            lax.cond(
                predicate,
                lambda _: host_callback.id_tap(tap_fn, tap_arg, result=iter_num),
                lambda _: iter_num,
                operand=None,
            )

        _tap_if(iter_num == 0, _define_tqdm, None)
        # Regular updates skip ``tail_start``; that index reports ``remainder``.
        _tap_if((iter_num % print_rate == 0) & (iter_num != tail_start),
                _update_tqdm, print_rate)
        _tap_if(iter_num == tail_start, _update_tqdm, remainder)
Exemple #4
0
 def _step(t_and_dt):
     """Run one step: input step, every target step, a progress tick, monitoring.

     ``t_and_dt[0]`` is passed as ``_t`` and ``t_and_dt[1]`` as ``_dt`` to each
     callee. Returns whatever ``self._monitor_step`` returns.
     """
     t, dt = t_and_dt[0], t_and_dt[1]
     self._input_step(_t=t, _dt=dt)
     for step_fn in self.target.steps.values():
         step_fn(_t=t, _dt=dt)
     # Empty payload: the host callback only calls ``self._pbar.update()``.
     id_tap(lambda *args: self._pbar.update(), ())
     return self._monitor_step(_t=t, _dt=dt)
Exemple #5
0
 def chunk_function(chunk_start, wieners_chunk):
     """Integrate one chunk of Wiener increments and return its final state.

     Parameters: chunk_start = (t0, x0, w0) values at the beginning of the chunk
                 wieners_chunk = array of Wiener increments
     Returns ``(final_state, final_state)``, presumably so this can be used as
     the body of an outer ``lax.scan`` (TODO confirm against caller).
     """
     id_tap(tap_func, 0)
     final_state, _chunk_trajectory = jax.lax.scan(
         scan_func, chunk_start, wieners_chunk)
     # The per-step trajectory at chunk resolution is discarded.
     return final_state, final_state
Exemple #6
0
 def _step(self, t_and_dt):
     # arguments
     kwargs = dict()
     kwargs.update(self.variables)
     kwargs.update({'t': t_and_dt[0], 'dt': t_and_dt[1]})
     kwargs.update(self._static_args)
     if len(self._dyn_args) > 0:
         kwargs.update({k: v[self.idx] for k, v in self._dyn_args.items()})
         self.idx += 1
     # call integrator function
     update_values = self.target(**kwargs)
     for i, v in enumerate(self.target.variables):
         self.variables[v].update(update_values[i])
     if self.progress_bar:
         id_tap(lambda *args: self._pbar.update(), ())
 def close_tqdm(result, iter_num):
     """Pass ``result`` through, closing the tqdm bar on the last iteration.

     The host callback ``_close_tqdm`` fires only when ``iter_num`` equals
     ``num_samples - 1`` (``num_samples`` is a closure variable); both branches
     of the cond yield ``result``.
     """
     on_last_iteration = iter_num == num_samples - 1
     return lax.cond(
         on_last_iteration,
         lambda _: host_callback.id_tap(_close_tqdm, None, result=result),
         lambda _: result,
         operand=None,
     )
Exemple #8
0
    def test_pytree(self, with_jit=False):
        """id_tap should deliver pytree arguments intact, with or without jit.

        For each pytree shape the host receives ``func(5, what)`` while the
        traced computation returns ``func(10, what)``; all three taps are counted.
        """
        def func(x, what=""):
            """Return the pytree selected by ``what``, built from ``x``."""
            if what == "pair_1_x":
                return (1, x)
            if what == "pair_x_2x":
                return (x, 2 * x)
            if what == "dict":
                return dict(a=2 * x, b=3 * x)
            assert False

        tap_count = 0

        def tap_func(a, what=""):
            # Runs on the host; ``what`` arrives as the extra id_tap keyword.
            nonlocal tap_count
            tap_count += 1
            self.assertEqual(func(5, what), a)

        def no_transform(f):
            return f

        transform = api.jit if with_jit else no_transform
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            for what in ("pair_1_x", "pair_x_2x", "dict"):
                expected = func(10, what)
                tapped = transform(
                    lambda x: hcb.id_tap(tap_func,
                                         func(x, what),
                                         result=func(x * 2, what),
                                         what=what))
                self.assertEqual(expected, tapped(5))
        # Receivers are expected to have finished once the `with` block exits.
        self.assertEqual(3, tap_count)
Exemple #9
0
 def _close_controller(cond, inputs):
     """Tap ``inputs`` to ``_close_pbar`` on the host when ``cond`` is false.

     When ``cond`` is true nothing is sent. Always returns ``cond`` unchanged.
     """
     def _leave_open(_):
         return inputs

     def _send_close(_):
         return id_tap(_close_pbar, inputs)

     jax.lax.cond(cond, _leave_open, _send_close, operand=None)
     return cond
    def _hcb_print(
        self,
        string_from_args: Callable[..., str],
        *args: hints.Pytree,
        **kwargs: hints.Pytree,
    ) -> None:
        """Helper for printing optimizer messages via host callbacks.

        On the host, prints ``[<class name>]`` followed by
        ``string_from_args(*args, **kwargs)``. No-op when ``self.verbose`` is
        falsy.
        """
        if not self.verbose:
            return

        def _print_on_host(args_kwargs, _unused_transforms):
            host_args, host_kwargs = args_kwargs[0], args_kwargs[1]
            print(
                f"[{type(self).__name__}]",
                string_from_args(*host_args, **host_kwargs),
            )

        hcb.id_tap(_print_on_host, (args, kwargs))
Exemple #11
0
    def _update_controller(args):
        """Conditionally forward a progress update to ``_update_pbar`` on the host.

        ``args[0]`` is the iteration counter. At positive multiples of
        ``print_rate`` other than ``max_iterations - remainder`` the tap receives
        ``(print_rate,) + args``; exactly at ``max_iterations - remainder`` it
        receives ``(remainder,) + args``.
        """
        counter = args[0]
        positive_multiple = np.logical_and(
            np.greater(counter, 0),
            np.equal(counter % print_rate, 0))
        regular_update = np.logical_and(
            positive_multiple,
            np.not_equal(counter, max_iterations - remainder))
        jax.lax.cond(
            regular_update,
            lambda _: id_tap(_update_pbar, (print_rate,) + args),
            lambda _: (print_rate,) + args,
            operand=None)

        # The tail index reports the leftover ``remainder`` instead.
        tail_update = np.equal(counter, max_iterations - remainder)
        jax.lax.cond(
            tail_update,
            lambda _: id_tap(_update_pbar, (remainder,) + args),
            lambda _: (remainder,) + args,
            operand=None)
Exemple #12
0
 def _step(state, i, *args):
     """Optionally log on the host, then perform one ``self.svi.update`` step.

     Every ``self.log_freq`` iterations ``self.log_func`` is tapped with
     ``(i, self.num_epochs)``; ``state`` passes through unchanged either way.
     """
     should_log = i % self.log_freq == 0
     state = lax.cond(
         should_log,
         lambda _: host_callback.id_tap(
             self.log_func, (i, self.num_epochs), result=state),
         lambda _: state,
         operand=None,
     )
     return self.svi.update(state, *args)
Exemple #13
0
    def _update_progress_bar(iter_num):
        """Update the tqdm progress bar of a scan/loop on selected iterations.

        Calls ``_update_tqdm(print_rate)`` on multiples of ``print_rate`` other
        than ``num_samples - remainder``, and ``_update_tqdm(remainder)`` exactly
        at ``num_samples - remainder``. ``print_rate``, ``num_samples`` and
        ``remainder`` come from the enclosing scope.
        """
        tail_start = num_samples - remainder

        def _advance_when(predicate, amount):
            # Both branches return ``iter_num``; only the true branch taps the host.
            lax.cond(
                predicate,
                lambda _: host_callback.id_tap(
                    _update_tqdm, amount, result=iter_num),
                lambda _: iter_num,
                operand=None,
            )

        _advance_when((iter_num % print_rate == 0) & (iter_num != tail_start),
                      print_rate)
        _advance_when(iter_num == tail_start, remainder)
Exemple #14
0
 def f(augmented: TheAugmentedState,
       x: None) -> Tuple[TheAugmentedState, Trajectory]:
     """One scan step: sample a trajectory and advance the augmented state.

     When ``tap_function`` is set, the trajectory is routed through ``id_tap``
     (with ``None`` as the tap payload) before being returned.
     """
     trajectory: Trajectory
     new_state, trajectory = self.sampled_state_trajectory(theta, augmented)
     next_augmented = self.iterate_augmented(new_state, augmented)
     if tap_function is None:
         return next_augmented, trajectory
     trajectory = id_tap(tap_function, None, result=trajectory)  # type: ignore[no-untyped-call]
     return next_augmented, trajectory
Exemple #15
0
def id_display(x: _T, name: Optional[str] = None, *, no_jvp: bool = False) -> _T:
    """Tap ``x`` to the host, print it there, and return ``x`` as the result.

    Args:
      x: value to display; also passed as the tap's ``result``.
      name: optional label. When given, ``x`` is printed as a keyword entry
        named ``name``, suffixed with `` [..]`` listing any ``jvp``/``mask``/
        ``transpose`` transforms the tap went through.
      no_jvp: when true, a tap that went through a ``jvp`` transform prints
        nothing.
    """
    def tap(x: _T, transforms: TapFunctionTransforms) -> None:
        # ``name`` is only read here, so no ``nonlocal`` declaration is needed.
        batch_dims: Optional[Tuple[Optional[int], ...]] = None
        flags = []
        for transform_name, transform_dict in transforms:
            if transform_name == 'batch':
                # Batch dims go to the printer instead of being listed as a flag.
                batch_dims = transform_dict['batch_dims']
            elif no_jvp and transform_name == 'jvp':
                return
            elif transform_name in ['jvp', 'mask', 'transpose']:
                flags.append(transform_name)
        if name is None:
            print_generic(x, batch_dims=batch_dims)
            return
        if flags:
            final_name = name + f" [{', '.join(flags)}]"
        else:
            final_name = name
        # https://github.com/python/mypy/issues/11583
        print_generic(batch_dims=batch_dims, **{final_name: x})  # type: ignore[arg-type]
    return id_tap(tap, x, result=x)  # type: ignore[no-untyped-call]
Exemple #16
0
 def func(x):
     """Print x1, tap ``_end_consumer`` (which ends the consumer loop), print x3."""
     first = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
     second = hcb.id_tap(hcb._end_consumer,
                         result=first + 1)  # Will end the consumer loop
     return hcb.id_print(second + 1, what="x3", output_stream=testing_stream)
Exemple #17
0
 def func(x):
     """Print x1, tap a consumer the receiver presumably does not know, print x3."""
     first = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
     second = hcb.id_tap(hcb._unknown_testing_consumer, first + 1, what="err")
     return hcb.id_print(second + 1, what="x3", output_stream=testing_stream)
Exemple #18
0
 def func(x):
     """Print x1, tap ``tap_err`` with what="err", then print x3."""
     first = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
     second = hcb.id_tap(tap_err, first + 1, what="err")
     return hcb.id_print(second + 1, what="x3", output_stream=testing_stream)
Exemple #19
0
 def func(x, count):
     """Chain ``count`` taps, each carrying ``nr_arrays`` shifted copies of ``x``.

     After each tap, ``x`` becomes the last element of what ``id_tap`` returns.
     """
     for step in range(count):
         # Use a distinct name so the offset does not shadow the loop counter.
         shifted = [x + offset for offset in range(nr_arrays)]
         x = hcb.id_tap(tap_func, shifted, i=step)[-1]
     return x
Exemple #20
0
 def do_tap(idx):
     """Run one jitted ``id_tap`` of ``idx`` through ``pause_tap``."""
     def tap_once(value):
         return hcb.id_tap(pause_tap, value)

     api.jit(tap_once)(idx)
Exemple #21
0
def nlls_id_print(it, x, end="\n"):
    """Tap ``(it, x)`` to ``nlls_printout`` on the host, forwarding ``end``.

    Returns whatever ``id_tap`` returns for the ``(it, x)`` payload.
    """
    tap_fn = partial(nlls_printout, end=end)
    payload = (it, x)
    return id_tap(tap_fn, payload)
 def long_run(x):
     """Tap ``x`` through ``pause_tap`` and return the tap's result."""
     tapped = hcb.id_tap(pause_tap, x)
     return tapped