def test_loop_over_tt(self, variant):
  """Scan `loop_over_tt` over a fixed horizon and check the u_in/pressure traces.

  Args:
    variant: wrapper applied to the step function before scanning (presumably
      a test-variant decorator such as jit / no-jit — TODO confirm).
  """
  dt = 0.03
  horizon = 29
  waveform = BreathWaveform.create()
  expiratory = Expiratory.create(waveform=waveform)
  controller_state = self.controller.init()
  expiratory_state = expiratory.init()
  state, obs = self.sim.reset()

  step_fn = variant(
      functools.partial(
          loop_over_tt,
          controller=self.controller,
          expiratory=expiratory,
          env=self.sim,
          dt=dt))
  # The scan's per-step outputs are (timestamp, u_in, u_out, pressure, flow);
  # only u_in and pressure are checked here.
  _, (_, u_ins, _, pressures, _) = jax.lax.scan(
      step_fn,
      (state, obs, controller_state, expiratory_state, 0),
      jnp.arange(horizon))
  print('test_loop_over_tt u_ins:' + str(list(u_ins)))
  print('test_loop_over_tt pressure:' + str(list(pressures)))
  self.assertTrue(jnp.allclose(u_ins, self.expected_u_ins))
  self.assertTrue(jnp.allclose(pressures, self.expected_pressures))
def rollout(controller, sim, tt, use_noise, peep, pip, loss_fn, loss):
  """Roll `controller` through `sim` over `tt` and accumulate a tracking loss.

  Args:
    controller: object with `init(waveform)` and
      `__call__(state, obs) -> (state, u_in)`.
    sim: simulator with `reset() -> (state, obs)` and
      `__call__(state, (u_in, u_out)) -> (state, obs)`.
    tt: sequence scanned over; each element is passed to `waveform.at`.
    use_noise: multiplier on the additive pressure noise (0 / False disables
      it).
    peep: lower end of `custom_range` passed to `BreathWaveform.create`.
    pip: upper end of `custom_range` passed to `BreathWaveform.create`.
    loss_fn: `loss_fn(target, pressure)`; added to the loss only on steps
      where the expiratory controller's output `u_out` equals 0.
    loss: initial loss value carried through the scan.

  Returns:
    The accumulated loss after scanning over `tt`.
  """
  waveform = BreathWaveform.create(custom_range=(peep, pip))
  expiratory = Expiratory.create(waveform=waveform)
  controller_state = controller.init(waveform)
  expiratory_state = expiratory.init()
  sim_state, obs = sim.reset()
  noise_mean = 1.5
  noise_std = 1.0

  def _step(carry, t):
    controller_state, expiratory_state, sim_state, obs, loss, key = carry
    # Split the key every step so successive steps draw independent noise;
    # re-creating a fixed key inside the step would give the same sample
    # at every step.
    key, subkey = jax.random.split(key)
    noise = noise_mean + noise_std * jax.random.normal(subkey, shape=())
    pressure = sim_state.predicted_pressure + use_noise * noise
    sim_state = sim_state.replace(predicted_pressure=pressure)
    obs = obs.replace(predicted_pressure=pressure)
    controller_state, u_in = controller(controller_state, obs)
    expiratory_state, u_out = expiratory(expiratory_state, obs)
    sim_state, obs = sim(sim_state, (u_in, u_out))
    # Score only the steps where u_out == 0, against the (noisy) pressure
    # observed before this step's action.
    loss = jax.lax.cond(
        u_out == 0,
        lambda x: x + loss_fn(jnp.array(waveform.at(t)), pressure),
        lambda x: x,
        loss)
    return (controller_state, expiratory_state, sim_state, obs, loss, key), None

  init_carry = (controller_state, expiratory_state, sim_state, obs, loss,
                jax.random.PRNGKey(0))
  (_, _, _, _, loss, _), _ = jax.lax.scan(_step, init_carry, tt)
  return loss
def test_decay(self):
  """Check Controller.decay at t=0.5, t=1.0 and t=1.5 on the default waveform."""
  controller = Controller()
  waveform = BreathWaveform.create()
  for t, expected in ((0.5, float('inf')), (1.0, 0.0)):
    self.assertEqual(controller.decay(waveform, t), expected)
  # At t=1.5 the expected value is built from the waveform's xp[2] breakpoint.
  actual = controller.decay(waveform, 1.5)
  expected = 5 * (1 - jnp.exp(5 * (waveform.xp[2] - waveform.elapsed(1.5))))
  self.assertEqual(actual, expected)
def setup(self):
  """Default the waveform if unset and derive `r0` from `min_volume`.

  `r0 = (3 * min_volume / (4 * pi)) ** (1/3)`, i.e. the radius of a sphere
  whose volume is `min_volume`.
  """
  if self.waveform is None:
    self.waveform = BreathWaveform.create()
  sphere_term = 3.0 * self.min_volume / (4.0 * jnp.pi)
  self.r0 = sphere_term**(1.0 / 3.0)
def test_at(self):
  """Spot-check BreathWaveform.at on the default waveform at several times."""
  waveform = BreathWaveform.create()
  cases = (
      (0., 35.0),
      (0.5, 35.0),
      (1.0, 35.0),
      (1.25, 20.0),
      (1.5, 5),
      (3.0, 35.0),
  )
  for t, expected in cases:
    self.assertEqual(waveform.at(t), expected)
def init(self, waveform=None):
  """Build the initial DeepControllerState.

  Args:
    waveform: target waveform; `BreathWaveform.create()` is used when None.

  Returns:
    DeepControllerState with zeroed back-history errors (length
    `self.back_history_len`) and forward targets sampled from `waveform` at
    `t * DEFAULT_DT` for `t` in `range(self.fwd_history_len)`.
  """
  waveform = BreathWaveform.create() if waveform is None else waveform
  return DeepControllerState(
      errs=jnp.array([0.0] * self.back_history_len),
      fwd_targets=jnp.array([
          waveform.at(step * DEFAULT_DT)
          for step in range(self.fwd_history_len)
      ]),
      waveform=waveform)
def run_balloon_lung():
  """Train a DeepAC agent with PPO on a BalloonLung simulator.

  The reward is the negative squared error between the observed pressure and
  the default BreathWaveform target at the observation's time.
  """
  target_waveform = BreathWaveform.create()

  def rewards_func(state, obs, action, env, counter):
    del state, action, env, counter
    tracking_error = obs.predicted_pressure - target_waveform.at(obs.time)
    return -tracking_error**2

  ppo.train(
      BalloonLung.create(),
      DeepAC.create(),
      rewards_func,
      horizon=29,
      config=get_config())
def run_learned(f):
  """Train a DeepAC agent with PPO against a simulator loaded from a pickle.

  Args:
    f: open binary file whose unpickled dict holds the simulator under
      "model_ckpt".
  """
  # NOTE(review): pickle.load can execute arbitrary code; only pass trusted
  # checkpoint files here.
  checkpoint = pickle.load(f)
  env = checkpoint["model_ckpt"]
  agent = DeepAC.create()
  target_waveform = BreathWaveform.create()

  def rewards_func(state, obs, action, env, counter):
    del state, action, env, counter
    tracking_error = obs.predicted_pressure - target_waveform.at(obs.time)
    return -tracking_error**2

  ppo.train(
      env=env,
      agent=agent,
      reward_fn=rewards_func,
      horizon=29,
      config=get_config())
def setup(self, waveform=None):
  """Initialize the actor-critic model, error featurizer, waveform and scalers.

  Args:
    waveform: optional target waveform. When given it is stored on
      `self.waveform`. When omitted, an already-configured `self.waveform` is
      kept and `BreathWaveform.create()` is used only if none is set.
  """
  self.model = ActorCritic()
  if self.params is None:
    # Initialize parameters from a dummy input of shape (1, 1, history_len).
    self.params = self.model.init(
        jax.random.PRNGKey(0),
        jnp.expand_dims(jnp.ones([self.history_len]), axis=(0, 1)))['params']

  # linear feature transform:
  # errs -> [average of last h errs, ..., average of last 2 errs, last err]
  # emulates low-pass filter bank
  self.featurizer = jnp.tril(jnp.ones((self.history_len, self.history_len)))
  self.featurizer /= jnp.expand_dims(
      jnp.arange(self.history_len, 0, -1), axis=0)

  # Honor an explicitly passed waveform, and never clobber a configured one
  # (matches the `self.waveform is None` defaulting used by other setups).
  if waveform is not None:
    self.waveform = waveform
  elif getattr(self, "waveform", None) is None:
    self.waveform = BreathWaveform.create()

  if self.normalize:
    self.u_scaler = u_scaler
    self.p_scaler = p_scaler
def test_controller(controller, sim, pips, peep):
  """Return the controller's mean absolute pressure error averaged over `pips`.

  For each pip, the controller is run for a 29-step horizon (via
  `run_controller_scan`) against `sim` on a waveform with the given peep/pip.
  The per-pip mean |pressure - target| is then averaged across all pips.
  Raises ZeroDivisionError if `pips` is empty.
  """
  horizon = 29
  total_error = 0.0
  for pip in pips:
    result = run_controller_scan(
        controller,
        T=horizon,
        abort=horizon,
        env=sim,
        waveform=BreathWaveform.create(peep=peep, pip=pip),
        init_controller=True,
    )
    analyzer = Analyzer(result)
    total_error += jnp.abs(analyzer.pressure - analyzer.target).mean()
  return total_error / len(pips)
def test_elapsed(self):
  """`elapsed` must be periodic: t and t + period give the same value."""
  waveform = BreathWaveform.create()
  # Use the TestCase assertion rather than a bare `assert`, which is stripped
  # when Python runs with -O.
  self.assertTrue(
      jnp.allclose(
          waveform.elapsed(1.3), waveform.elapsed(1.3 + waveform.period)))
def setup(self):
  """Default the waveform if unset and derive `r0` from `min_volume`.

  `r0` is the radius of a sphere of volume `min_volume`:
  (3 * min_volume / (4 * pi)) ** (1 / 3).
  """
  if self.waveform is None:
    self.waveform = BreathWaveform.create()
  sphere_term = 3 * self.min_volume / (4 * jnp.pi)
  self.r0 = sphere_term**(1 / 3)
def setup(self):
  """Set `self.waveform` to `BreathWaveform.create()` when it is None."""
  if self.waveform is not None:
    return
  self.waveform = BreathWaveform.create()
def run_controller(
    controller,
    T=1000,
    dt=0.03,
    abort=60,
    env=None,
    waveform=None,
    use_tqdm=False,
    directory=None,
):
  """Run `controller` against `env` step by step for up to `T` steps.

  Args:
    controller: object with `init()` and `__call__(state, obs) -> (state, u_in)`.
    T: maximum number of control steps.
    dt: control period; subtracted from `env.time(state)` for the recorded
      timestamp, and used to compute the per-step `env.wait` duration.
    abort: not used by this function; kept for interface compatibility.
    env: environment; defaults to `BalloonLung()` when None.
    waveform: target waveform; defaults to `BreathWaveform.create()` when None.
    use_tqdm: wrap the step loop in a tqdm progress bar.
    directory: if not None, the result dict is pickled into this directory
      under a timestamped filename (the directory is created if needed).

  Returns:
    dict with "timeseries" (timestamp, pressure, flow, target, u_in, u_out
    arrays of length `T`; entries after an early `env.should_abort()` stay 0)
    and "waveform".
  """
  # Explicit None checks: object truthiness is not a reliable "not passed" test.
  if env is None:
    env = BalloonLung()
  if waveform is None:
    waveform = BreathWaveform.create()
  expiratory = Expiratory.create(waveform=waveform)
  result = {}

  controller_state = controller.init()
  expiratory_state = expiratory.init()

  tt = range(T)
  if use_tqdm:
    tt = tqdm.tqdm(tt, leave=False)

  timestamps = jnp.zeros(T)
  pressures = jnp.zeros(T)
  flows = jnp.zeros(T)
  u_ins = jnp.zeros(T)
  u_outs = jnp.zeros(T)

  state, obs = env.reset()

  try:
    for i, _ in enumerate(tt):
      # Record the pressure observed before this step's action is applied.
      pressure = obs.predicted_pressure
      if env.should_abort():
        break

      controller_state, u_in = controller.__call__(controller_state, obs)
      expiratory_state, u_out = expiratory.__call__(expiratory_state, obs)

      u_in = u_in.squeeze()
      state, obs = env(state, (u_in, u_out))

      timestamps = timestamps.at[i].set(env.time(state) - dt)
      u_ins = u_ins.at[i].set(u_in)
      u_outs = u_outs.at[i].set(u_out)
      pressures = pressures.at[i].set(pressure)
      flows = flows.at[i].set(env.flow)

      # Wait for the part of dt beyond env.dt (never a negative duration).
      env.wait(max(dt - env.dt, 0))
  finally:
    env.cleanup()

  timeseries = {
      "timestamp": jnp.array(timestamps),
      "pressure": jnp.array(pressures),
      "flow": jnp.array(flows),
      "target": waveform.at(timestamps),
      "u_in": jnp.array(u_ins),
      "u_out": jnp.array(u_outs),
  }

  for key, val in timeseries.items():
    timeseries[key] = val[:T + 1]

  result["timeseries"] = timeseries
  result["waveform"] = waveform

  if directory is not None:
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(directory, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    # Context manager guarantees the pickle file is flushed and closed.
    with open(f"{directory}/{timestamp}.pkl", "wb") as out_file:
      pickle.dump(result, out_file)

  return result
def run_controller_scan(
    controller,
    T=1000,
    dt=0.03,
    abort=60,
    env=None,
    waveform=None,
    use_tqdm=False,
    directory=None,
    init_controller=False,
):
  """Run `controller` against `env` for `T` steps using `jax.lax.scan`.

  Args:
    controller: object with `init(...)`; passed to `loop_over_tt`.
    T: number of scan steps.
    dt: control period forwarded to `loop_over_tt`.
    abort: not used by this function; kept for interface compatibility.
    env: environment; defaults to `BalloonLung()` when None.
    waveform: target waveform; defaults to `BreathWaveform.create()` when None.
    use_tqdm: has no effect here (the scan has no Python-level step loop);
      kept for interface compatibility with `run_controller`.
    directory: if not None, the result dict is pickled into this directory
      under a timestamped filename (the directory is created if needed).
    init_controller: if True, call `controller.init(waveform)`, otherwise
      `controller.init()`.

  Returns:
    dict with "timeseries" (timestamp, pressure, flow, target, u_in, u_out)
    and "waveform".
  """
  # Explicit None checks: object truthiness is not a reliable "not passed" test.
  if env is None:
    env = BalloonLung()
  if waveform is None:
    waveform = BreathWaveform.create()
  expiratory = Expiratory.create(waveform=waveform)
  result = {}

  if init_controller:
    controller_state = controller.init(waveform)
  else:
    controller_state = controller.init()
  expiratory_state = expiratory.init()

  state, obs = env.reset()

  jit_loop_over_tt = jax.jit(
      functools.partial(
          loop_over_tt,
          controller=controller,
          expiratory=expiratory,
          env=env,
          dt=dt))

  try:
    # Per-step scan outputs are (timestamp, u_in, u_out, pressure, flow).
    _, (timestamps, u_ins, u_outs, pressures, flows) = jax.lax.scan(
        jit_loop_over_tt,
        (state, obs, controller_state, expiratory_state, 0),
        jnp.arange(T))
  finally:
    env.cleanup()

  timeseries = {
      "timestamp": jnp.array(timestamps),
      "pressure": jnp.array(pressures),
      "flow": jnp.array(flows),
      "target": waveform.at(timestamps),
      "u_in": jnp.array(u_ins),
      "u_out": jnp.array(u_outs),
  }

  for key, val in timeseries.items():
    timeseries[key] = val[:T + 1]

  result["timeseries"] = timeseries
  result["waveform"] = waveform

  if directory is not None:
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(directory, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    # Context manager guarantees the pickle file is flushed and closed.
    with open(f"{directory}/{timestamp}.pkl", "wb") as out_file:
      pickle.dump(result, out_file)

  return result
def test_pip(self):
  """BreathWaveform.create must expose the pip it was created with."""
  expected_pip = 10
  self.assertEqual(expected_pip,
                   BreathWaveform.create(peep=0, pip=expected_pip).pip)
def test_peep(self):
  """BreathWaveform.create must expose the peep it was created with."""
  expected_peep = 50
  self.assertEqual(expected_peep,
                   BreathWaveform.create(peep=expected_peep, pip=10).peep)
def test_is_ex(self):
  """Default waveform: is_ex is False at t=0.5 and t=1.0, True at t=1.5."""
  waveform = BreathWaveform.create()
  for t, expected in ((0.5, False), (1.0, False), (1.5, True)):
    self.assertEqual(waveform.is_ex(t), expected)
def init(self, waveform=None):
  """Return a DeepControllerState with a zeroed error history.

  Args:
    waveform: target waveform; `BreathWaveform.create()` is used when None.
  """
  chosen_waveform = BreathWaveform.create() if waveform is None else waveform
  zero_errs = jnp.array([0.0] * self.history_len)
  return DeepControllerState(errs=zero_errs, waveform=chosen_waveform)
def init(self, waveform=None):
  """Return a PIDControllerState bound to `waveform` (default if None)."""
  return PIDControllerState(
      waveform=BreathWaveform.create() if waveform is None else waveform)