Esempio n. 1
0
    def __call__(self, controller_state, obs):
        """Advance the controller by one step and compute the next ``u_in``.

        The tracking error (waveform target at ``obs.time`` minus
        ``obs.predicted_pressure``) updates three terms kept on the state:
        ``p`` becomes the raw error, while ``i`` and ``d`` are blended
        towards their new values with factor ``dt / (dt + RC)``. The stacked
        terms are passed to ``self.model_apply`` and the output is clamped to
        ``[0, 100]``.

        Args:
            controller_state: state exposing ``waveform``, ``p``, ``i``,
                ``d``, ``time`` and ``steps``, plus a ``replace`` method.
            obs: observation exposing ``predicted_pressure`` and ``time``.

        Returns:
            ``(controller_state, u_in)``: the updated state and the clamped
            control value.
        """
        now = obs.time
        measured = obs.predicted_pressure
        error = jnp.array(controller_state.waveform.at(now) - measured)
        blend = jnp.array(self.dt / (self.dt + self.RC))

        prev_p = controller_state.p
        prev_i = controller_state.i
        prev_d = controller_state.d
        coef_p = error
        coef_i = prev_i + blend * (error - prev_i)
        # The d-term tracks the change in error relative to the previous p.
        coef_d = prev_d + blend * (error - prev_p - prev_d)
        controller_state = controller_state.replace(p=coef_p,
                                                    i=coef_i,
                                                    d=coef_d)

        features = jnp.array([coef_p, coef_i, coef_d])
        raw_u_in = self.model_apply({"params": self.params}, features)
        u_in = jax.lax.clamp(0.0, raw_u_in.astype(jnp.float64), 100.0)

        # Bookkeeping: the stored dt never drops below DEFAULT_DT.
        elapsed = now - proper_time(controller_state.time)
        dt_candidates = jnp.array([DEFAULT_DT, elapsed])
        controller_state = controller_state.replace(
            time=now,
            steps=controller_state.steps + 1,
            dt=jnp.max(dt_candidates))
        return controller_state, u_in
Esempio n. 2
0
 def __call__(self, state, obs, *args, **kwargs):
     """Return the stored action for the current step and advance the state.

     The action is a length-1 slice of ``self.u_ins`` starting at
     ``state.steps``. Extra positional/keyword arguments are accepted and
     ignored.

     Args:
         state: state exposing ``steps`` and ``time`` plus ``replace``.
         obs: observation exposing ``time``.

     Returns:
         ``(state, action)``: the updated state and the sliced action.
     """
     step_index = state.steps.astype(int)
     action = jax.lax.dynamic_slice(self.u_ins, (step_index, ), (1, ))

     now = obs.time
     # The stored dt never drops below DEFAULT_DT.
     elapsed = now - proper_time(state.time)
     dt_candidates = jnp.array([DEFAULT_DT, elapsed])
     state = state.replace(time=now,
                           steps=state.steps + 1,
                           dt=jnp.max(dt_candidates))
     return state, action
Esempio n. 3
0
 def __call__(self, state, obs, *args, **kwargs):
     """Compute ``u_out`` from the waveform phase and advance the state.

     ``u_out`` is 0 when ``self.waveform.is_in(obs.time)`` holds and 1
     otherwise. Extra positional/keyword arguments are accepted and ignored.

     Args:
         state: state exposing ``steps`` and ``time`` plus ``replace``.
         obs: observation exposing ``time``.

     Returns:
         ``(state, u_out)``: the updated state and the 0/1 output.
     """
     now = obs.time

     def _zero(_):
         return 0

     def _one(_):
         return 1

     # The operand is unused by either branch; cond just requires one.
     u_out = jax.lax.cond(self.waveform.is_in(now), _zero, _one,
                          jnp.zeros_like(now))

     # The stored dt never drops below DEFAULT_DT.
     elapsed = now - proper_time(state.time)
     dt_candidates = jnp.array([DEFAULT_DT, elapsed])
     state = state.replace(time=now,
                           steps=state.steps + 1,
                           dt=jnp.max(dt_candidates))
     return state, u_out
Esempio n. 4
0
    def __call__(self, controller_state, obs):
        """Advance the controller by one step and compute the next ``u_in``.

        Two rolling buffers on the state are shifted left by one and get a
        new last entry:

        * ``errs``: target minus predicted pressure (both passed through
          ``self.p_normalizer`` first when ``self.normalize`` is set).
        * ``fwd_targets``: the waveform target ``self.fwd_history_len``
          default steps ahead, or the previous last entry once that time
          reaches ``self.horizon * DEFAULT_DT``. When
          ``self.fwd_history_len`` is not positive this buffer is replaced
          with an empty array.

        If ``self.decay(waveform, t)`` is infinite the concatenated buffers
        are fed to ``self.model_apply``; otherwise the decay value itself is
        used. The result is clamped to ``[0, self.clip]``.

        Args:
            controller_state: state exposing ``errs``, ``fwd_targets``,
                ``waveform``, ``time`` and ``steps``, plus ``replace``.
            obs: observation exposing ``predicted_pressure`` and ``time``.

        Returns:
            ``(controller_state, u_in)``: the updated state and the clamped
            control value.
        """
        now = obs.time
        pressure = obs.predicted_pressure
        waveform = controller_state.waveform
        prev_fwd_targets = controller_state.fwd_targets
        use_lookahead = self.fwd_history_len > 0
        target = waveform.at(now)

        if use_lookahead:
            lookahead_time = now + self.fwd_history_len * DEFAULT_DT
            # Past the horizon, keep repeating the last known forward target.
            fwd_target = jax.lax.cond(
                lookahead_time >= self.horizon * DEFAULT_DT,
                lambda _: prev_fwd_targets[-1],
                lambda _: waveform.at(lookahead_time), None)

        # Newest tracking error, optionally in normalized pressure units.
        if self.normalize:
            latest_err = (self.p_normalizer(target).squeeze() -
                          self.p_normalizer(pressure).squeeze())
        else:
            latest_err = target - pressure
        shifted_errs = jnp.roll(controller_state.errs, shift=-1)
        next_errs = shifted_errs.at[-1].set(latest_err)

        if use_lookahead:
            if self.normalize:
                newest_fwd = self.p_normalizer(fwd_target).squeeze()
            else:
                newest_fwd = fwd_target
            shifted_fwd = jnp.roll(prev_fwd_targets, shift=-1)
            next_fwd_targets = shifted_fwd.at[-1].set(newest_fwd)
        else:
            next_fwd_targets = jnp.array([])

        controller_state = controller_state.replace(
            errs=next_errs, fwd_targets=next_fwd_targets)
        decay = self.decay(waveform, now)

        def _model_action(_):
            model_input = jnp.hstack([next_errs, next_fwd_targets])
            out = self.model_apply({"params": self.params}, model_input)
            return out.squeeze().astype(jnp.float64)

        # An infinite decay (rather than None) selects the model branch,
        # because cond needs an array-valued predicate.
        u_in = jax.lax.cond(jnp.isinf(decay), _model_action,
                            lambda _: jnp.array(decay), None)
        u_in = jax.lax.clamp(0.0, u_in.astype(jnp.float64),
                             self.clip).squeeze()

        # Bookkeeping: the stored dt never drops below DEFAULT_DT.
        elapsed = now - proper_time(controller_state.time)
        dt_candidates = jnp.array([DEFAULT_DT, elapsed])
        controller_state = controller_state.replace(
            time=now,
            steps=controller_state.steps + 1,
            dt=jnp.max(dt_candidates))
        return controller_state, u_in
Esempio n. 5
0
 def __call__(self, controller_state, obs):
     """Pick ``max_action`` or ``min_action`` and advance the state.

     The action is ``self.max_action`` while the predicted pressure is below
     the waveform target at ``obs.time`` and ``self.min_action`` otherwise.

     Args:
         controller_state: state exposing ``time`` and ``steps`` plus
             ``replace``.
         obs: observation exposing ``predicted_pressure`` and ``time``.

     Returns:
         ``(controller_state, action)``: the updated state and the action.
     """
     now = obs.time
     below_target = obs.predicted_pressure < self.waveform.at(now)

     def _push(_):
         return self.max_action

     def _release(_):
         return self.min_action

     action = jax.lax.cond(below_target, _push, _release, None)

     # Bookkeeping: the stored dt never drops below DEFAULT_DT.
     elapsed = now - proper_time(controller_state.time)
     dt_candidates = jnp.array([DEFAULT_DT, elapsed])
     controller_state = controller_state.replace(
         time=now,
         steps=controller_state.steps + 1,
         dt=jnp.max(dt_candidates))
     return controller_state, action
Esempio n. 6
0
    def __call__(self, controller_state, obs):
        """Advance the controller by one step and compute the next ``u_in``.

        The rolling ``errs`` buffer is shifted left by one and its last entry
        is set to target minus predicted pressure (both passed through
        ``self.p_normalizer`` first when ``self.normalize`` is set).

        If ``self.decay(waveform, t)`` is infinite, the last
        ``self.history_len`` errors are projected through ``self.featurizer``
        and fed to ``self.model_apply``; otherwise the decay value itself is
        used. The result is then limited to ``[0, self.clip]`` with either a
        hard clamp or, when ``self.use_leaky_clamp`` is set, a leaky clamp.

        Args:
            controller_state: state exposing ``errs``, ``waveform``, ``time``
                and ``steps``, plus ``replace``.
            obs: observation exposing ``predicted_pressure`` and ``time``.

        Returns:
            ``(controller_state, u_in)``: the updated state and the limited
            control value.
        """
        state, t = obs.predicted_pressure, obs.time
        errs, waveform = controller_state.errs, controller_state.waveform
        target = waveform.at(t)

        # Newest tracking error, optionally in normalized pressure units.
        if self.normalize:
            latest_err = (self.p_normalizer(target).squeeze() -
                          self.p_normalizer(state).squeeze())
        else:
            latest_err = target - state
        next_errs = jnp.roll(errs, shift=-1).at[-1].set(latest_err)
        controller_state = controller_state.replace(errs=next_errs)
        decay = self.decay(waveform, t)

        def true_func(null_arg):
            trajectory = jnp.expand_dims(next_errs[-self.history_len:],
                                         axis=(0, 1))
            input_val = jnp.reshape((trajectory @ self.featurizer),
                                    (1, self.history_len, 1))
            u_in = self.model_apply({"params": self.params}, input_val)
            return u_in.squeeze().astype(jnp.float32)

        # An infinite decay (rather than None) selects the model branch,
        # because cond needs an array-valued predicate.
        u_in = jax.lax.cond(jnp.isinf(decay), true_func,
                            lambda x: jnp.array(decay), None)
        if self.use_leaky_clamp:
            # Leaky clamp: identity inside [0, clip] and slope 0.01 outside,
            # so the gradient never vanishes. The excess over the limit is
            # what gets scaled, which keeps the function continuous at both
            # edges (scaling x itself above clip would jump by 0.01 * clip).
            u_in = jax.lax.cond(u_in < 0.0, lambda x: x * 0.01, lambda x: x,
                                u_in)
            u_in = jax.lax.cond(u_in > self.clip,
                                lambda x: self.clip + (x - self.clip) * 0.01,
                                lambda x: x, u_in)
        else:
            u_in = jax.lax.clamp(0.0, u_in.astype(jnp.float32),
                                 self.clip).squeeze()

        # Bookkeeping: the stored dt never drops below DEFAULT_DT.
        new_dt = jnp.max(
            jnp.array([DEFAULT_DT, t - proper_time(controller_state.time)]))
        controller_state = controller_state.replace(
            time=t, steps=controller_state.steps + 1, dt=new_dt)
        return controller_state, u_in
Esempio n. 7
0
  def __call__(self, controller_state, obs):
    """Sample an action from the policy and advance the controller state.

    The rolling ``errs`` buffer is shifted left by one and its last entry is
    set to target minus predicted pressure (both passed through
    ``self.p_scaler`` first when ``self.normalize`` is set). A copy of the
    state with batched ``errs`` is given to ``self.derive_prob_and_value``,
    and an action index is drawn from the resulting probabilities.

    Args:
      controller_state: state exposing ``errs``, ``key``, ``time`` and
        ``steps``, plus ``replace``.
      obs: observation exposing ``predicted_pressure`` and ``time``.

    Returns:
      ``(controller_state, (u_in, 0), log_prob[u_in], value)``: the updated
      state, the sampled index paired with a constant 0, the log-probability
      of the sampled index, and the value estimate.
    """
    now = obs.time
    pressure = obs.predicted_pressure
    waveform = self.waveform
    target = waveform.at(now)

    # Newest tracking error, optionally in scaled pressure units.
    if self.normalize:
      latest_err = (self.p_scaler(target).squeeze() -
                    self.p_scaler(pressure).squeeze())
    else:
      latest_err = target - pressure
    next_errs = jnp.roll(controller_state.errs, shift=-1)
    next_errs = next_errs.at[-1].set(latest_err)

    sample_key, carry_key = jax.random.split(controller_state.key)
    # derive_prob_and_value is handed errs with a leading batch axis.
    batched_state = controller_state.replace(
        errs=jnp.expand_dims(next_errs, axis=(0)), key=sample_key)
    # TODO(@namanagarwal): decay is not consumed yet; the plan is to switch
    # on jnp.isinf(decay) with jax.lax.cond as the other controllers do.
    decay = self.decay(waveform, now)
    log_prob, value = self.derive_prob_and_value(batched_state)
    prob = jnp.exp(log_prob)
    # Environment step: draw an action index according to prob.
    u_in = jax.random.choice(sample_key, prob.shape[0], p=prob)

    # TODO(@namanagarwal): figure out what to do with clamping; the sampled
    # index is currently returned unclamped.

    # Bookkeeping: the stored dt never drops below DEFAULT_DT, and the
    # unbatched errs plus the second half of the key split are carried on.
    elapsed = now - proper_time(controller_state.time)
    dt_candidates = jnp.array([DEFAULT_DT, elapsed])
    controller_state = controller_state.replace(
        time=now,
        steps=controller_state.steps + 1,
        dt=jnp.max(dt_candidates),
        key=carry_key,
        errs=next_errs)
    return controller_state, (u_in, 0), log_prob[u_in], value
Esempio n. 8
0
 def test_proper_time(self):
   """proper_time maps an infinite time to 0 and leaves 10.0 unchanged."""
   infinite_time = float('inf')
   self.assertEqual(proper_time(infinite_time), 0)
   finite_time = 10.0
   self.assertEqual(proper_time(finite_time), 10.0)