def _update_progress_bar(num_iter):
    """Drive a tqdm progress bar from inside a traced JAX loop body.

    Every decision is made with ``lax.cond`` so it stays traceable; the
    host-side work happens through ``host_callback.id_tap``:

    * ``num_iter == 0``: ``_define_tqdm`` is tapped with ``None``;
    * ``num_iter`` a multiple of ``print_rate`` other than
      ``num_samples - remainder``: ``_update_tqdm`` is tapped with
      ``print_rate``;
    * ``num_iter == num_samples - remainder``: ``_update_tqdm`` is tapped
      with ``remainder``.

    ``print_rate``, ``num_samples``, ``remainder``, ``_define_tqdm`` and
    ``_update_tqdm`` are closure variables of the enclosing scope (not
    visible here). Returns ``None``; only the host-side taps matter.
    """
    def _tap_when(predicate, tap_func, tap_arg):
        # lax.cond needs both branches to return the same structure: the
        # taken branch returns id_tap's ``result`` (num_iter) and the
        # skipped branch returns num_iter directly.
        return lax.cond(
            predicate,
            lambda _: host_callback.id_tap(tap_func, tap_arg, result=num_iter),
            lambda _: num_iter,
            operand=None,
        )

    last_tick = num_samples - remainder
    _tap_when(num_iter == 0, _define_tqdm, None)
    # Regular ticks skip `last_tick`; that iteration advances by `remainder`.
    _tap_when(
        (num_iter % print_rate == 0) & (num_iter != last_tick),
        _update_tqdm,
        print_rate,
    )
    _tap_when(num_iter == last_tick, _update_tqdm, remainder)
def _update_progress_bar(iter_num):
    """Advance a tqdm progress bar from inside a traced JAX loop body.

    Three host taps are each guarded by ``lax.cond``, and all of them use
    ``tap_with_device=True``:

    * ``iter_num == 1``: ``_update_tqdm`` is tapped with ``0``;
    * ``iter_num`` a multiple of ``print_rate``: ``_update_tqdm`` is tapped
      with ``print_rate``;
    * ``iter_num == num_samples``: ``_close_tqdm`` is tapped with
      ``remainder``.

    ``print_rate``, ``num_samples``, ``remainder``, ``_update_tqdm`` and
    ``_close_tqdm`` are closure variables of the enclosing scope. Returns
    ``None``.

    NOTE(review): when ``num_samples`` is itself a multiple of
    ``print_rate``, both the second and third taps fire on the last
    iteration. Presumably ``remainder`` is 0 in that case; confirm.
    """
    def _tap_if(predicate, tap_func, tap_arg):
        # Taken branch taps the host and yields iter_num via `result=`;
        # skipped branch yields iter_num directly, so both branches match.
        return lax.cond(
            predicate,
            lambda _: host_callback.id_tap(
                tap_func, tap_arg, result=iter_num, tap_with_device=True),
            lambda _: iter_num,
            operand=None,
        )

    _tap_if(iter_num == 1, _update_tqdm, 0)
    _tap_if(iter_num % print_rate == 0, _update_tqdm, print_rate)
    _tap_if(iter_num == num_samples, _close_tqdm, remainder)
def _update_progress_bar(iter_num):
    """Update a tqdm progress bar from inside a JAX scan or loop body.

    Issues up to three host taps, each guarded by ``lax.cond`` so the
    decision is made in traced code:

    1. ``iter_num == 0``: ``_define_tqdm`` is tapped with ``None``.
    2. ``iter_num`` a multiple of ``print_rate`` but not equal to
       ``num_samples - remainder``: ``_update_tqdm`` is tapped with
       ``print_rate``.
    3. ``iter_num == num_samples - remainder``: ``_update_tqdm`` is tapped
       with ``remainder``.

    ``print_rate``, ``num_samples``, ``remainder``, ``_define_tqdm`` and
    ``_update_tqdm`` are closure variables of the enclosing scope. Returns
    ``None``.
    """
    final_iter = num_samples - remainder

    def _tapper(tap_func, tap_arg):
        # Taken branch: tap the host, then yield iter_num (id_tap's `result`)
        # so it matches the untaken branch's return value.
        return lambda _: host_callback.id_tap(tap_func, tap_arg, result=iter_num)

    schedule = (
        (iter_num == 0, _define_tqdm, None),
        ((iter_num % print_rate == 0) & (iter_num != final_iter),
         _update_tqdm, print_rate),
        (iter_num == final_iter, _update_tqdm, remainder),
    )
    for predicate, tap_func, tap_arg in schedule:
        lax.cond(predicate, _tapper(tap_func, tap_arg), lambda _: iter_num,
                 operand=None)
def _step(t_and_dt):
    """Advance the simulation by one time step.

    ``t_and_dt`` is indexed as ``(t, dt)``. The step does the following in
    order:

    * applies inputs via ``self._input_step``;
    * calls every step function in ``self.target.steps``;
    * ticks the progress bar through a host tap;
    * returns the result of ``self._monitor_step`` for this step.

    ``self`` is a closure variable from the enclosing method (not visible
    here).
    """
    cur_t = t_and_dt[0]
    cur_dt = t_and_dt[1]
    self._input_step(_t=cur_t, _dt=cur_dt)
    for step_fn in self.target.steps.values():
        step_fn(_t=cur_t, _dt=cur_dt)
    # Host-side tqdm update; the value id_tap returns is not used.
    id_tap(lambda *args: self._pbar.update(), ())
    return self._monitor_step(_t=cur_t, _dt=cur_dt)
def chunk_function(chunk_start, wieners_chunk):
    """Integrate one chunk of a trajectory with ``jax.lax.scan``.

    Args:
      chunk_start: ``(t0, x0, w0)`` values at the beginning of the chunk.
      wieners_chunk: array of Wiener increments for this chunk.

    Returns:
      ``(z, z)``, where ``z`` is the final scan carry. The per-step scan
      outputs are discarded, so only chunk-resolution state is kept. The
      pair shape suggests this is itself used as a ``scan`` body; TODO
      confirm.
    """
    # Host tap with a dummy argument (0), issued once per call.
    id_tap(tap_func, 0)
    final_carry, _per_step = jax.lax.scan(scan_func, chunk_start, wieners_chunk)
    return final_carry, final_carry
def _step(self, t_and_dt): # arguments kwargs = dict() kwargs.update(self.variables) kwargs.update({'t': t_and_dt[0], 'dt': t_and_dt[1]}) kwargs.update(self._static_args) if len(self._dyn_args) > 0: kwargs.update({k: v[self.idx] for k, v in self._dyn_args.items()}) self.idx += 1 # call integrator function update_values = self.target(**kwargs) for i, v in enumerate(self.target.variables): self.variables[v].update(update_values[i]) if self.progress_bar: id_tap(lambda *args: self._pbar.update(), ())
def close_tqdm(result, iter_num):
    """Close the progress bar on the final iteration, passing ``result`` through.

    When ``iter_num == num_samples - 1``, the host callback ``_close_tqdm``
    is tapped with ``None`` as its argument; otherwise nothing is tapped.
    Either way ``result`` is returned unchanged, so this can wrap a loop
    body's output. ``num_samples`` and ``_close_tqdm`` come from the
    enclosing scope.
    """
    is_last_iter = iter_num == num_samples - 1

    def _close_and_pass(_):
        return host_callback.id_tap(_close_tqdm, None, result=result)

    def _pass(_):
        return result

    return lax.cond(is_last_iter, _close_and_pass, _pass, operand=None)
def test_pytree(self, with_jit=False):
    """Tap pytree arguments and results (tuples, dicts) through ``hcb.id_tap``.

    For each pytree kind, the test checks that:

    * the host-side tap receives ``func(5, kind)``;
    * ``id_tap`` returns its ``result`` (``func(10, kind)``) unchanged,
      optionally under ``api.jit``.

    After the receiver context is exited, the test checks that the tap ran
    exactly three times.
    """
    def func(x, what=""):
        """Return a pytree built from ``x``; ``what`` selects its shape."""
        if what == "pair_1_x":
            return (1, x)
        if what == "pair_x_2x":
            return (x, 2 * x)
        if what == "dict":
            return dict(a=2 * x, b=3 * x)
        assert False

    tap_count = 0

    def tap_func(a, what=""):
        # Host-side receiver: count the call and check the tapped pytree.
        nonlocal tap_count
        tap_count += 1
        self.assertEqual(func(5, what), a)

    transform = api.jit if with_jit else (lambda fn: fn)

    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
        for kind in ("pair_1_x", "pair_x_2x", "dict"):
            def tapped(x):
                return hcb.id_tap(tap_func, func(x, kind),
                                  result=func(x * 2, kind), what=kind)

            self.assertEqual(func(10, kind), transform(tapped)(5))
    # The receiver context has been exited here, presumably so all
    # outstanding taps have been delivered before counting.
    self.assertEqual(3, tap_count)
def _close_controller(cond, inputs):
    """Tap ``_close_pbar`` with ``inputs`` when ``cond`` is false.

    The true branch taps nothing and just returns ``inputs``. Returns
    ``cond`` unchanged. ``_close_pbar`` and ``id_tap`` come from the
    enclosing scope.
    """
    def _no_tap(_):
        return inputs

    def _close(_):
        return id_tap(_close_pbar, inputs)

    jax.lax.cond(cond, _no_tap, _close, operand=None)
    return cond
def _hcb_print(
    self,
    string_from_args: Callable[..., str],
    *args: hints.Pytree,
    **kwargs: hints.Pytree,
) -> None:
    """Print an optimizer message through a host callback.

    On the host, ``string_from_args(*args, **kwargs)`` is evaluated and
    printed after a ``[<ClassName>]`` prefix. Does nothing when
    ``self.verbose`` is falsy.
    """
    if self.verbose:
        def _host_print(args_kwargs, _unused_transforms):
            host_args, host_kwargs = args_kwargs
            print(
                f"[{type(self).__name__}]",
                string_from_args(*host_args, **host_kwargs),
            )

        hcb.id_tap(_host_print, (args, kwargs))
def _update_controller(args):
    """Tap ``_update_pbar`` to advance a progress bar from a traced loop.

    ``args[0]`` is the iteration counter. Two taps may be issued:

    * ``counter > 0``, a multiple of ``print_rate``, and not equal to
      ``max_iterations - remainder``: tapped with ``(print_rate,) + args``;
    * ``counter == max_iterations - remainder``: tapped with
      ``(remainder,) + args``.

    ``print_rate``, ``max_iterations``, ``remainder``, ``_update_pbar``,
    ``id_tap`` and ``np`` come from the enclosing scope. Returns ``None``.
    """
    counter = args[0]
    final_tick = max_iterations - remainder
    regular_payload = (print_rate,) + args
    remainder_payload = (remainder,) + args

    positive_multiple = np.logical_and(
        np.greater(counter, 0), np.equal(counter % print_rate, 0))
    is_regular_tick = np.logical_and(
        positive_multiple, np.not_equal(counter, final_tick))
    jax.lax.cond(
        is_regular_tick,
        lambda _: id_tap(_update_pbar, regular_payload),
        lambda _: regular_payload,
        operand=None,
    )

    is_final_tick = np.equal(counter, final_tick)
    jax.lax.cond(
        is_final_tick,
        lambda _: id_tap(_update_pbar, remainder_payload),
        lambda _: remainder_payload,
        operand=None,
    )
def _step(state, i, *args):
    """Run one SVI update, logging through a host tap every ``self.log_freq`` steps.

    When ``i % self.log_freq == 0``, the host callback ``self.log_func`` is
    tapped with ``(i, self.num_epochs)``; ``state`` passes through
    unchanged. Returns ``self.svi.update(state, *args)``. ``self`` is a
    closure variable of the enclosing scope.
    """
    should_log = i % self.log_freq == 0

    def _log(_):
        return host_callback.id_tap(self.log_func, (i, self.num_epochs), result=state)

    def _skip(_):
        return state

    passed_state = lax.cond(should_log, _log, _skip, operand=None)
    return self.svi.update(passed_state, *args)
def _update_progress_bar(iter_num):
    """Update a tqdm progress bar from a scan or loop body.

    The bar advances only on specific iterations:

    * ``iter_num`` a multiple of ``print_rate`` other than
      ``num_samples - remainder``: ``_update_tqdm`` is tapped with
      ``print_rate``;
    * ``iter_num == num_samples - remainder``: ``_update_tqdm`` is tapped
      with ``remainder``.

    ``print_rate``, ``num_samples``, ``remainder`` and ``_update_tqdm`` come
    from the enclosing scope. Returns ``None``.
    """
    final_iter = num_samples - remainder
    on_regular_tick = (iter_num % print_rate == 0) & (iter_num != final_iter)
    on_final_iter = iter_num == final_iter

    def _advance_by(amount):
        # Taken branch of lax.cond: tap the host with the step size and
        # hand back iter_num so both branches return the same value.
        return lambda _: host_callback.id_tap(_update_tqdm, amount, result=iter_num)

    def _skip(_):
        return iter_num

    lax.cond(on_regular_tick, _advance_by(print_rate), _skip, operand=None)
    lax.cond(on_final_iter, _advance_by(remainder), _skip, operand=None)
def f(augmented: TheAugmentedState, x: None) -> Tuple[TheAugmentedState, Trajectory]:
    """Sample one trajectory and advance the augmented state.

    The ``(carry, x) -> (carry, y)`` shape matches a ``lax.scan`` body,
    which is presumably how it is used; ``x`` is ignored. When
    ``tap_function`` (an enclosing-scope variable) is not ``None``, the
    trajectory is routed through ``id_tap`` so the host sees it. ``theta``
    and ``self`` also come from the enclosing scope.
    """
    trajectory: Trajectory
    new_state, trajectory = self.sampled_state_trajectory(theta, augmented)
    new_augmented = self.iterate_augmented(new_state, augmented)
    if tap_function is None:
        return new_augmented, trajectory
    tapped_trajectory = id_tap(tap_function,  # type: ignore[no-untyped-call]
                               None, result=trajectory)
    return new_augmented, tapped_trajectory
def id_display(x: _T, name: Optional[str] = None, *, no_jvp: bool = False) -> _T:
    """Print ``x`` on the host when the computation runs, and return ``x``.

    Args:
      x: Value to display; returned unchanged.
      name: Optional label. When given, the value is printed under that
        keyword name. Any ``jvp``, ``mask`` or ``transpose`` transforms
        applied to the tap are appended in brackets. When ``None``, the value
        is printed unlabeled and the transform flags are not shown.
      no_jvp: If true, nothing is printed for a tap seen under a ``jvp``
        transform.

    Returns:
      ``x``, passed through ``id_tap``.
    """
    def tap(x: _T, transforms: TapFunctionTransforms) -> None:
        batch_dims: Optional[Tuple[Optional[int], ...]] = None
        flags = []
        for transform_name, transform_dict in transforms:
            if transform_name == 'batch':
                batch_dims = transform_dict['batch_dims']
            elif transform_name == 'jvp' and no_jvp:
                return
            elif transform_name in ['jvp', 'mask', 'transpose']:
                flags.append(transform_name)

        if name is None:
            print_generic(x, batch_dims=batch_dims)
            return
        final_name = name + f" [{', '.join(flags)}]" if flags else name
        # https://github.com/python/mypy/issues/11583
        print_generic(batch_dims=batch_dims, **{final_name: x})  # type: ignore[arg-type]

    return id_tap(tap, x, result=x)  # type: ignore[no-untyped-call]
def func(x):
    """Print ``x + 1``, end the consumer loop through a tap, then print again.

    The first print is tagged ``"x1"``. ``hcb._end_consumer`` is then tapped
    with ``result=x1 + 1``. The final print of that result plus one is
    tagged ``"x3"`` and is the return value.
    """
    x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    # Tapping `_end_consumer` ends the consumer loop.
    # NOTE(review): no positional `arg` is passed to id_tap here; newer
    # host_callback signatures require one. Confirm against the JAX version.
    x2 = hcb.id_tap(hcb._end_consumer, result=x1 + 1)
    return hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
def func(x):
    """Print ``x + 1``, tap an unknown testing consumer, then print again.

    Steps:

    * print ``x + 1`` tagged ``"x1"``;
    * tap ``hcb._unknown_testing_consumer`` with ``x1 + 1`` tagged
      ``what="err"`` (presumably to exercise consumer error handling);
    * print the tap's result plus one, tagged ``"x3"``, and return it.
    """
    x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    x2 = hcb.id_tap(hcb._unknown_testing_consumer, x1 + 1, what="err")
    return hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
def func(x):
    """Print ``x + 1``, tap ``tap_err``, then print again.

    Steps:

    * print ``x + 1`` tagged ``"x1"``;
    * tap ``tap_err`` with ``x1 + 1`` tagged ``what="err"``;
    * print the tap's result plus one, tagged ``"x3"``, and return it.

    ``tap_err`` and ``testing_stream`` come from the enclosing scope.
    """
    x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    x2 = hcb.id_tap(tap_err, x1 + 1, what="err")
    return hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
def func(x, count):
    """Tap ``nr_arrays`` shifted copies of ``x`` ``count`` times in a row.

    Each round does the following:

    * tap ``[x + 0, ..., x + nr_arrays - 1]``, passing the round number as
      ``i=``;
    * continue from the last element of the tap's result.

    Returns the final ``x``.
    """
    for round_idx in range(count):
        # Offsets run over range(nr_arrays), independent of the round index.
        shifted = [x + offset for offset in range(nr_arrays)]
        x = hcb.id_tap(tap_func, shifted, i=round_idx)[-1]
    return x
def do_tap(idx):
    """Call, under ``api.jit``, a function that taps ``pause_tap`` with ``idx``.

    Returns ``None``; the jitted call's result is discarded.
    """
    def _tap_once(value):
        return hcb.id_tap(pause_tap, value)

    api.jit(_tap_once)(idx)
def nlls_id_print(it, x, end="\n"):
    """Host-print ``(it, x)`` with ``nlls_printout`` and return the tapped value.

    ``end`` is forwarded to ``nlls_printout`` as a keyword argument. Returns
    whatever ``id_tap`` returns for the ``(it, x)`` argument.
    """
    payload = (it, x)
    return id_tap(partial(nlls_printout, end=end), payload)
def long_run(x):
    """Tap ``pause_tap`` with ``x`` and return the result of ``hcb.id_tap``."""
    tapped = hcb.id_tap(pause_tap, x)
    return tapped