Exemplo n.º 1
0
    def test_loop_over_tt(self, variant):
        """Scan ``loop_over_tt`` for ``horizon`` steps and check its traces.

        Args:
          variant: transformation the test harness applies to the
            partially-applied step function (presumably a chex jit/no-jit
            variant -- TODO confirm against the test decorator).

        The scanned ``u_in`` and pressure traces must be close to
        ``self.expected_u_ins`` and ``self.expected_pressures``.
        """
        dt = 0.03
        horizon = 29
        waveform = BreathWaveform.create()
        expiratory = Expiratory.create(waveform=waveform)
        controller_state = self.controller.init()
        expiratory_state = expiratory.init()
        state, obs = self.sim.reset()

        variant_loop_over_tt = variant(
            functools.partial(loop_over_tt,
                              controller=self.controller,
                              expiratory=expiratory,
                              env=self.sim,
                              dt=dt))

        # Carry: (sim state, obs, controller state, expiratory state, counter
        # starting at 0). Per-step outputs are unpacked in the same order
        # run_controller_scan uses: (timestamp, u_in, u_out, pressure, flow).
        _, (_, u_ins, _, pressures, _) = jax.lax.scan(
            variant_loop_over_tt,
            (state, obs, controller_state, expiratory_state, 0),
            jnp.arange(horizon))
        print('test_loop_over_tt u_ins:' + str(list(u_ins)))
        print('test_loop_over_tt pressure:' + str(list(pressures)))

        self.assertTrue(jnp.allclose(u_ins, self.expected_u_ins))
        self.assertTrue(jnp.allclose(pressures, self.expected_pressures))
Exemplo n.º 2
0
def rollout(controller, sim, tt, use_noise, peep, pip, loss_fn, loss):
    """Roll ``controller`` out on ``sim`` over ``tt`` and accumulate a loss.

    Args:
      controller: object with ``init(waveform)`` returning a state and
        ``__call__(state, obs)`` returning ``(state, u_in)``.
      sim: simulator with ``reset()`` returning ``(state, obs)`` and
        ``__call__(state, (u_in, u_out))`` returning ``(state, obs)``.
      tt: sequence scanned over; each element is passed to ``waveform.at``.
      use_noise: multiplier applied to the noise added to the observed
        pressure (e.g. 0/1).
      peep: lower end of the target waveform's ``custom_range``.
      pip: upper end of the target waveform's ``custom_range``.
      loss_fn: called as ``loss_fn(target, pressure)``; its result is added to
        the loss on steps where the expiratory controller returns
        ``u_out == 0``.
      loss: initial value of the loss accumulator.

    Returns:
      The accumulated loss after scanning over all of ``tt``.
    """
    waveform = BreathWaveform.create(custom_range=(peep, pip))
    expiratory = Expiratory.create(waveform=waveform)
    controller_state = controller.init(waveform)
    expiratory_state = expiratory.init()
    sim_state, obs = sim.reset()

    # The noise sample uses a fixed PRNG key and does not depend on the step,
    # so it is the same scalar at every step; compute it once, outside the scan.
    # NOTE(review): if independent per-step noise is intended, the key must
    # vary per step (e.g. jax.random.fold_in(key, t)) -- confirm intent.
    mean = 1.5
    std = 1.0
    noise = mean + std * jax.random.normal(jax.random.PRNGKey(0), shape=())

    def loop_over_tt(carry, t):
        controller_state, expiratory_state, sim_state, obs, loss = carry
        # Perturb the observed pressure before the controllers see it.
        pressure = sim_state.predicted_pressure + use_noise * noise
        sim_state = sim_state.replace(predicted_pressure=pressure)
        obs = obs.replace(predicted_pressure=pressure)
        controller_state, u_in = controller(controller_state, obs)
        expiratory_state, u_out = expiratory(expiratory_state, obs)
        sim_state, obs = sim(sim_state, (u_in, u_out))
        # Only steps with u_out == 0 contribute to the loss.
        loss = jax.lax.cond(
            u_out == 0,
            lambda x: x + loss_fn(jnp.array(waveform.at(t)), pressure),
            lambda x: x, loss)
        return (controller_state, expiratory_state, sim_state, obs, loss), None

    (_, _, _, _, loss), _ = jax.lax.scan(
        loop_over_tt,
        (controller_state, expiratory_state, sim_state, obs, loss), tt)
    return loss
Exemplo n.º 3
0
 def test_decay(self):
   """Check ``Controller.decay`` at three times on the default waveform."""
   ctrl = Controller()
   wf = BreathWaveform.create()
   decay = ctrl.decay
   self.assertEqual(decay(wf, 0.5), float('inf'))
   self.assertEqual(decay(wf, 1.0), 0.0)
   actual = decay(wf, 1.5)
   expected = 5 * (1 - jnp.exp(5 * (wf.xp[2] - wf.elapsed(1.5))))
   self.assertEqual(actual, expected)
Exemplo n.º 4
0
    def setup(self):
        if self.waveform is None:
            self.waveform = BreathWaveform.create()

        # dynamics hyperparameters
        # self.time = 0.0

        self.r0 = (3.0 * self.min_volume / (4.0 * jnp.pi))**(1.0 / 3.0)
Exemplo n.º 5
0
 def test_at(self):
   """Check ``BreathWaveform.at`` at sample times on the default waveform."""
   wf = BreathWaveform.create()
   cases = (
       (0., 35.0),
       (0.5, 35.0),
       (1.0, 35.0),
       (1.25, 20.0),
       (1.5, 5),
       (3.0, 35.0),
   )
   for when, expected in cases:
     self.assertEqual(wf.at(when), expected)
Exemplo n.º 6
0
 def init(self, waveform=None):
     """Build the initial ``DeepControllerState``.

     Args:
       waveform: target waveform; ``BreathWaveform.create()`` when None.

     Returns:
       A ``DeepControllerState`` with ``back_history_len`` zero errors,
       ``fwd_targets`` holding ``waveform.at(k * DEFAULT_DT)`` for
       ``k in range(fwd_history_len)``, and the waveform itself.
     """
     waveform = BreathWaveform.create() if waveform is None else waveform
     zero_errs = jnp.array([0.0] * self.back_history_len)
     targets = [waveform.at(k * DEFAULT_DT) for k in range(self.fwd_history_len)]
     return DeepControllerState(errs=zero_errs,
                                fwd_targets=jnp.array(targets),
                                waveform=waveform)
Exemplo n.º 7
0
def run_balloon_lung():
    """Train a ``DeepAC`` agent with PPO on a ``BalloonLung`` environment.

    The per-step reward is the negative squared error between the observed
    pressure and the default waveform's target at ``obs.time``.
    """
    target_waveform = BreathWaveform.create()

    def rewards_func(state, obs, action, env, counter):
        del state, action, env, counter
        error = obs.predicted_pressure - target_waveform.at(obs.time)
        return -error**2

    lung = BalloonLung.create()
    agent = DeepAC.create()
    ppo.train(lung, agent, rewards_func, horizon=29, config=get_config())
Exemplo n.º 8
0
def run_learned(f):
    """Train a ``DeepAC`` agent with PPO on a learned (pickled) simulator.

    Args:
      f: binary file object holding a pickled mapping whose ``"model_ckpt"``
        entry is used as the training environment.
    """
    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # only pass checkpoint files from a trusted source.
    checkpoint = pickle.load(f)
    env = checkpoint["model_ckpt"]

    agent = DeepAC.create()
    target_waveform = BreathWaveform.create()

    def rewards_func(state, obs, action, env, counter):
        del state, action, env, counter
        error = obs.predicted_pressure - target_waveform.at(obs.time)
        return -error**2

    ppo.train(
        env=env,
        agent=agent,
        reward_fn=rewards_func,
        horizon=29,
        config=get_config(),
    )
Exemplo n.º 9
0
  def setup(self, waveform=None):
    """Create the actor-critic model, the error featurizer and the waveform.

    Args:
      waveform: optional target waveform. When given it is stored on
        ``self.waveform``. Otherwise an already-set ``self.waveform`` is kept,
        and ``BreathWaveform.create()`` is used only when none is set.
    """
    self.model = ActorCritic()
    if self.params is None:
      # Initialize parameters from a dummy input of shape (1, 1, history_len).
      self.params = self.model.init(
          jax.random.PRNGKey(0),
          jnp.expand_dims(jnp.ones([self.history_len]), axis=(0, 1)))['params']

    # linear feature transform:
    # errs -> [average of last h errs, ..., average of last 2 errs, last err]
    # emulates low-pass filter bank
    # (column j of the lower-triangular ones matrix sums history_len - j
    # entries and is divided by history_len - j)
    self.featurizer = jnp.tril(jnp.ones((self.history_len, self.history_len)))
    self.featurizer /= jnp.expand_dims(
        jnp.arange(self.history_len, 0, -1), axis=0)

    # Prefer an explicitly passed waveform, then an already-set one, and fall
    # back to the default only when neither exists.
    if waveform is not None:
      self.waveform = waveform
    elif getattr(self, 'waveform', None) is None:
      self.waveform = BreathWaveform.create()

    if self.normalize:
      self.u_scaler = u_scaler
      self.p_scaler = p_scaler
Exemplo n.º 10
0
def test_controller(controller, sim, pips, peep, horizon=29):
    """Score ``controller`` by its mean absolute pressure-tracking error.

    For each PIP in ``pips``, a waveform from ``peep`` to that PIP is built.
    The controller is run on ``sim`` for ``horizon`` steps via
    ``run_controller_scan``, and the mean of ``|pressure - target|`` is taken.
    The per-PIP errors are averaged.

    Args:
      controller: controller to evaluate; ``init(waveform)`` is used.
      sim: environment passed to ``run_controller_scan`` as ``env``.
      pips: non-empty sequence of peak inspiratory pressures to evaluate.
      peep: positive end-expiratory pressure shared by all waveforms.
      horizon: number of simulated steps per PIP (default 29).

    Returns:
      The average over ``pips`` of the mean absolute error.

    Raises:
      ZeroDivisionError: if ``pips`` is empty.
    """
    score = 0.0
    for pip in pips:
        waveform = BreathWaveform.create(peep=peep, pip=pip)
        result = run_controller_scan(
            controller,
            T=horizon,
            abort=horizon,
            env=sim,
            waveform=waveform,
            init_controller=True,
        )
        analyzer = Analyzer(result)
        preds = analyzer.pressure  # presumably shape (horizon,) -- confirm
        truth = analyzer.target  # presumably shape (horizon,) -- confirm
        score += jnp.abs(preds - truth).mean()
    return score / len(pips)
Exemplo n.º 11
0
 def test_elapsed(self):
   """Check that ``elapsed`` repeats after one ``period``."""
   waveform = BreathWaveform.create()
   # self.assertTrue (unlike a bare assert) is not stripped under `python -O`
   # and matches the unittest-style assertions used by the other tests.
   self.assertTrue(
       jnp.allclose(
           waveform.elapsed(1.3), waveform.elapsed(1.3 + waveform.period)))
Exemplo n.º 12
0
 def setup(self):
     if self.waveform is None:
         self.waveform = BreathWaveform.create()
     self.r0 = (3 * self.min_volume / (4 * jnp.pi))**(1 / 3)
Exemplo n.º 13
0
 def setup(self):
     if self.waveform is None:
         self.waveform = BreathWaveform.create()
Exemplo n.º 14
0
def run_controller(
    controller,
    T=1000,
    dt=0.03,
    abort=60,
    env=None,
    waveform=None,
    use_tqdm=False,
    directory=None,
    init_controller=False,
):
    """Run ``controller`` against ``env`` step by step for up to ``T`` steps.

    Args:
      controller: object with ``init(...)`` returning a state and
        ``__call__(state, obs)`` returning ``(state, u_in)``.
      T: maximum number of steps (length of the recorded arrays).
      dt: step duration. Recorded timestamps are ``env.time(state) - dt``,
        and the loop waits ``max(dt - env.dt, 0)`` after each step.
      abort: not used by this function. The loop stops early only when
        ``env.should_abort()`` is true.
      env: environment; ``BalloonLung()`` when None.
      waveform: target waveform; ``BreathWaveform.create()`` when None.
      use_tqdm: show a ``tqdm`` progress bar over the steps.
      directory: if given, the result is pickled to
        ``{directory}/{timestamp}.pkl``; the directory is created if missing.
      init_controller: if True, call ``controller.init(waveform)`` instead of
        ``controller.init()`` (same option as ``run_controller_scan``).

    Returns:
      A dict with ``"timeseries"`` and ``"waveform"``. ``"timeseries"`` holds
      length-``T`` arrays keyed timestamp, pressure, flow, target, u_in and
      u_out. Entries for steps not reached after an abort stay zero, except
      ``target``, which is ``waveform.at(timestamps)``.
    """
    env = BalloonLung() if env is None else env
    waveform = BreathWaveform.create() if waveform is None else waveform
    expiratory = Expiratory.create(waveform=waveform)

    if init_controller:
        controller_state = controller.init(waveform)
    else:
        controller_state = controller.init()
    expiratory_state = expiratory.init()

    steps = range(T)
    if use_tqdm:
        steps = tqdm.tqdm(steps, leave=False)

    timestamps = jnp.zeros(T)
    pressures = jnp.zeros(T)
    flows = jnp.zeros(T)
    u_ins = jnp.zeros(T)
    u_outs = jnp.zeros(T)

    state, obs = env.reset()

    try:
        for i in steps:
            # The pressure recorded for step i is the one observed before this
            # step's controls are applied.
            pressure = obs.predicted_pressure
            if env.should_abort():
                break

            controller_state, u_in = controller.__call__(controller_state, obs)
            expiratory_state, u_out = expiratory.__call__(
                expiratory_state, obs)

            u_in = u_in.squeeze()
            state, obs = env(state, (u_in, u_out))

            timestamps = timestamps.at[i].set(env.time(state) - dt)
            u_ins = u_ins.at[i].set(u_in)
            u_outs = u_outs.at[i].set(u_out)
            pressures = pressures.at[i].set(pressure)
            flows = flows.at[i].set(env.flow)

            env.wait(max(dt - env.dt, 0))

    finally:
        env.cleanup()

    timeseries = {
        "timestamp": jnp.array(timestamps),
        "pressure": jnp.array(pressures),
        "flow": jnp.array(flows),
        "target": waveform.at(timestamps),
        "u_in": jnp.array(u_ins),
        "u_out": jnp.array(u_outs),
    }

    result = {"timeseries": timeseries, "waveform": waveform}

    if directory is not None:
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(directory, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        # Context manager guarantees the file is flushed and closed.
        with open(f"{directory}/{timestamp}.pkl", "wb") as out_file:
            pickle.dump(result, out_file)

    return result
Exemplo n.º 15
0
def run_controller_scan(
    controller,
    T=1000,
    dt=0.03,
    abort=60,
    env=None,
    waveform=None,
    use_tqdm=False,
    directory=None,
    init_controller=False,
):
    """Run ``controller`` for ``T`` steps using ``jax.lax.scan``.

    The step function is ``loop_over_tt`` with ``controller``, the expiratory
    controller, ``env`` and ``dt`` bound. It is jitted and scanned over
    ``jnp.arange(T)``.

    Args:
      controller: object with ``init()`` / ``init(waveform)`` returning a
        state; driven by ``loop_over_tt``.
      T: number of scanned steps.
      dt: step duration, forwarded to ``loop_over_tt``.
      abort: not used by this implementation; kept for signature
        compatibility with ``run_controller``.
      env: environment; ``BalloonLung()`` when None.
      waveform: target waveform; ``BreathWaveform.create()`` when None.
      use_tqdm: has no effect here, since the scan has no Python-level loop
        to report progress on; kept for signature compatibility.
      directory: if given, the result is pickled to
        ``{directory}/{timestamp}.pkl``; the directory is created if missing.
      init_controller: if True, call ``controller.init(waveform)`` instead of
        ``controller.init()``.

    Returns:
      A dict with ``"timeseries"`` and ``"waveform"``. ``"timeseries"`` holds
      the scanned timestamp, pressure, flow, u_in and u_out arrays, plus
      ``target = waveform.at(timestamps)``.
    """
    env = BalloonLung() if env is None else env
    waveform = BreathWaveform.create() if waveform is None else waveform
    expiratory = Expiratory.create(waveform=waveform)

    if init_controller:
        controller_state = controller.init(waveform)
    else:
        controller_state = controller.init()
    expiratory_state = expiratory.init()

    state, obs = env.reset()

    jit_loop_over_tt = jax.jit(
        functools.partial(loop_over_tt,
                          controller=controller,
                          expiratory=expiratory,
                          env=env,
                          dt=dt))
    try:
        # Initial carry ends with a step counter of 0; per-step outputs are
        # stacked along the leading axis by the scan.
        _, (timestamps, u_ins, u_outs, pressures, flows) = jax.lax.scan(
            jit_loop_over_tt,
            (state, obs, controller_state, expiratory_state, 0), jnp.arange(T))

    finally:
        env.cleanup()

    timeseries = {
        "timestamp": jnp.array(timestamps),
        "pressure": jnp.array(pressures),
        "flow": jnp.array(flows),
        "target": waveform.at(timestamps),
        "u_in": jnp.array(u_ins),
        "u_out": jnp.array(u_outs),
    }

    result = {"timeseries": timeseries, "waveform": waveform}

    if directory is not None:
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(directory, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        # Context manager guarantees the file is flushed and closed.
        with open(f"{directory}/{timestamp}.pkl", "wb") as out_file:
            pickle.dump(result, out_file)

    return result
Exemplo n.º 16
0
  def test_pip(self):
    """The ``pip`` passed to ``BreathWaveform.create`` is exposed as ``.pip``."""
    expected_pip = 10
    wf = BreathWaveform.create(peep=0, pip=expected_pip)
    self.assertEqual(expected_pip, wf.pip)
Exemplo n.º 17
0
  def test_peep(self):
    """The ``peep`` passed to ``BreathWaveform.create`` is exposed as ``.peep``."""
    expected_peep = 50
    wf = BreathWaveform.create(peep=expected_peep, pip=10)
    self.assertEqual(expected_peep, wf.peep)
Exemplo n.º 18
0
 def test_is_ex(self):
   """Check ``BreathWaveform.is_ex`` at sample times on the default waveform."""
   wf = BreathWaveform.create()
   for when, expected in ((0.5, False), (1.0, False), (1.5, True)):
     self.assertEqual(wf.is_ex(when), expected)
Exemplo n.º 19
0
 def init(self, waveform=None):
     """Return a fresh ``DeepControllerState``.

     Args:
       waveform: target waveform; ``BreathWaveform.create()`` when None.

     Returns:
       A state with ``history_len`` zero errors and the chosen waveform.
     """
     chosen = BreathWaveform.create() if waveform is None else waveform
     return DeepControllerState(errs=jnp.array([0.0] * self.history_len),
                                waveform=chosen)
Exemplo n.º 20
0
 def init(self, waveform=None):
     """Return a fresh ``PIDControllerState`` for ``waveform``.

     Args:
       waveform: target waveform; ``BreathWaveform.create()`` when None.
     """
     chosen = BreathWaveform.create() if waveform is None else waveform
     return PIDControllerState(waveform=chosen)