def __call__(self, controller_state, obs):
    """Advance the learned-PID controller by one observation.

    The P term is the current tracking error. The I and D terms are updated
    by exponential smoothing with gain ``self.dt / (self.dt + self.RC)``.
    The three terms are fed to ``self.model_apply`` and the resulting
    command is clamped to ``[0, 100]``.

    Args:
      controller_state: state carrying ``waveform``, ``p``, ``i``, ``d``,
        ``time`` and ``steps``.
      obs: observation exposing ``predicted_pressure`` and ``time``.

    Returns:
      ``(controller_state, u_in)``: the updated state and the clamped command.
    """
    t = obs.time
    error = jnp.array(controller_state.waveform.at(t) - obs.predicted_pressure)
    alpha = jnp.array(self.dt / (self.dt + self.RC))

    prev_p = controller_state.p
    prev_i = controller_state.i
    prev_d = controller_state.d
    new_i = prev_i + alpha * (error - prev_i)
    # D term smooths the change in error relative to the previous P term.
    new_d = prev_d + alpha * (error - prev_p - prev_d)
    controller_state = controller_state.replace(p=error, i=new_i, d=new_d)

    raw = self.model_apply({"params": self.params},
                           jnp.array([error, new_i, new_d]))
    u_in = jax.lax.clamp(0.0, raw.astype(jnp.float64), 100.0)

    # Book-keeping: dt never drops below DEFAULT_DT.
    elapsed = t - proper_time(controller_state.time)
    controller_state = controller_state.replace(
        time=t,
        steps=controller_state.steps + 1,
        dt=jnp.max(jnp.array([DEFAULT_DT, elapsed])),
    )
    return controller_state, u_in
def __call__(self, state, obs, *args, **kwargs):
    """Replay the pre-recorded control sequence ``self.u_ins``.

    Returns the length-1 slice of ``self.u_ins`` that starts at index
    ``state.steps`` and advances the step counter. Extra positional and
    keyword arguments are accepted and ignored.

    NOTE(review): ``jax.lax.dynamic_slice`` clamps out-of-range start
    indices, so stepping past the end keeps returning the last entry.
    """
    index = state.steps.astype(int)
    action = jax.lax.dynamic_slice(self.u_ins, (index,), (1,))

    now = obs.time
    dt = jnp.max(jnp.array([DEFAULT_DT, now - proper_time(state.time)]))
    return state.replace(time=now, steps=state.steps + 1, dt=dt), action
def __call__(self, state, obs, *args, **kwargs):
    """Emit the ``u_out`` command for the current time.

    ``u_out`` is 0 while ``self.waveform.is_in(obs.time)`` is true and 1
    otherwise. Extra positional and keyword arguments are ignored.
    """
    now = obs.time
    in_window = self.waveform.is_in(now)
    u_out = jax.lax.cond(in_window, lambda _: 0, lambda _: 1,
                         jnp.zeros_like(now))

    # Book-keeping: dt never drops below DEFAULT_DT.
    elapsed = now - proper_time(state.time)
    next_state = state.replace(
        time=now,
        steps=state.steps + 1,
        dt=jnp.max(jnp.array([DEFAULT_DT, elapsed])),
    )
    return next_state, u_out
def __call__(self, controller_state, obs):
    """Advance the deep controller (error history + look-ahead targets) one step.

    Each call does the following:

    - Appends the newest tracking error (``target - pressure``, normalized
      with ``self.p_normalizer`` when ``self.normalize`` is set) to the
      rolling ``errs`` buffer.
    - When ``self.fwd_history_len > 0``, appends the waveform target
      ``fwd_history_len * DEFAULT_DT`` ahead to the rolling ``fwd_targets``
      buffer, in the same units as ``errs``. Once that look-ahead time
      reaches ``self.horizon * DEFAULT_DT``, the last stored forward target
      is repeated instead.
    - Feeds both buffers to the model when ``self.decay`` is infinite.
      Otherwise the decay value itself is used as the command.
    - Clamps the command to ``[0, self.clip]``.

    Args:
      controller_state: state carrying ``errs``, ``fwd_targets``,
        ``waveform``, ``time`` and ``steps``.
      obs: observation exposing ``predicted_pressure`` and ``time``.

    Returns:
      ``(controller_state, u_in)``: the updated state and the clamped command.
    """
    state, t = obs.predicted_pressure, obs.time
    errs, waveform = controller_state.errs, controller_state.waveform
    fwd_targets = controller_state.fwd_targets
    target = waveform.at(t)

    if self.normalize:
        newest_err = (self.p_normalizer(target).squeeze() -
                      self.p_normalizer(state).squeeze())
    else:
        newest_err = target - state
    next_errs = jnp.roll(errs, shift=-1).at[-1].set(newest_err)

    if self.fwd_history_len > 0:
        fwd_t = t + self.fwd_history_len * DEFAULT_DT
        fwd_target = waveform.at(fwd_t)
        if self.normalize:
            fwd_target = self.p_normalizer(fwd_target).squeeze()
        # Past the horizon, repeat the last *stored* entry. It is already in
        # the buffer's units, so it must not be normalized a second time.
        # Previously it was re-normalized on every step past the horizon.
        newest_fwd = jnp.where(fwd_t >= self.horizon * DEFAULT_DT,
                               fwd_targets[-1], fwd_target)
        next_fwd_targets = jnp.roll(fwd_targets, shift=-1).at[-1].set(newest_fwd)
    else:
        next_fwd_targets = jnp.array([])

    controller_state = controller_state.replace(
        errs=next_errs, fwd_targets=next_fwd_targets)
    decay = self.decay(waveform, t)

    def true_func(null_arg):
        trajectory = jnp.hstack([next_errs, next_fwd_targets])
        u_in = self.model_apply({"params": self.params}, trajectory)
        return u_in.squeeze().astype(jnp.float64)

    # jax.lax.cond needs an array-valued decay, so "use the model" is
    # signalled by an infinite decay rather than None.
    u_in = jax.lax.cond(jnp.isinf(decay), true_func,
                        lambda x: jnp.array(decay), None)
    u_in = jax.lax.clamp(0.0, u_in.astype(jnp.float64), self.clip).squeeze()

    # Book-keeping: dt never drops below DEFAULT_DT.
    new_dt = jnp.max(
        jnp.array([DEFAULT_DT, t - proper_time(controller_state.time)]))
    controller_state = controller_state.replace(
        time=t, steps=controller_state.steps + 1, dt=new_dt)
    return controller_state, u_in
def __call__(self, controller_state, obs):
    """Bang-bang control toward the waveform target.

    Emits ``self.max_action`` while the predicted pressure is below
    ``self.waveform.at(obs.time)``, and ``self.min_action`` otherwise.

    Returns:
      ``(controller_state, action)``.
    """
    now = obs.time
    below_target = obs.predicted_pressure < self.waveform.at(now)
    action = jax.lax.cond(
        below_target,
        lambda _: self.max_action,
        lambda _: self.min_action,
        None,
    )

    # Book-keeping: dt never drops below DEFAULT_DT.
    elapsed = now - proper_time(controller_state.time)
    controller_state = controller_state.replace(
        time=now,
        steps=controller_state.steps + 1,
        dt=jnp.max(jnp.array([DEFAULT_DT, elapsed])),
    )
    return controller_state, action
def __call__(self, controller_state, obs):
    """Advance the featurized deep controller by one observation.

    Each call does the following:

    - Appends the newest tracking error (normalized with
      ``self.p_normalizer`` when ``self.normalize`` is set) to the rolling
      ``errs`` buffer.
    - Multiplies the last ``self.history_len`` errors by
      ``self.featurizer`` and feeds the result to the model when
      ``self.decay`` is infinite. Otherwise the decay value is used
      directly.
    - Bounds the command to ``[0, self.clip]``. With
      ``self.use_leaky_clamp`` the bound is instead a continuous leaky
      clamp with slope 0.01 outside the range.

    Args:
      controller_state: state carrying ``errs``, ``waveform``, ``time`` and
        ``steps``.
      obs: observation exposing ``predicted_pressure`` and ``time``.

    Returns:
      ``(controller_state, u_in)``: the updated state and the bounded command.
    """
    state, t = obs.predicted_pressure, obs.time
    errs, waveform = controller_state.errs, controller_state.waveform
    target = waveform.at(t)

    if self.normalize:
        newest_err = (self.p_normalizer(target).squeeze() -
                      self.p_normalizer(state).squeeze())
    else:
        newest_err = target - state
    next_errs = jnp.roll(errs, shift=-1).at[-1].set(newest_err)
    controller_state = controller_state.replace(errs=next_errs)

    decay = self.decay(waveform, t)

    def true_func(null_arg):
        # Shape after expand_dims is (1, 1, history_len).
        # NOTE(review): assumes self.featurizer has history_len rows and
        # produces history_len features so the reshape below is valid.
        trajectory = jnp.expand_dims(next_errs[-self.history_len:], axis=(0, 1))
        input_val = jnp.reshape((trajectory @ self.featurizer),
                                (1, self.history_len, 1))
        u_in = self.model_apply({"params": self.params}, input_val)
        return u_in.squeeze().astype(jnp.float32)

    # jax.lax.cond needs an array-valued decay, so "use the model" is
    # signalled by an infinite decay rather than None.
    u_in = jax.lax.cond(jnp.isinf(decay), true_func,
                        lambda x: jnp.array(decay), None)

    if self.use_leaky_clamp:
        # Leaky clamp: slope 0.01 outside [0, clip] keeps gradients non-zero.
        # Both branches are anchored at their limit so the output is
        # continuous. The upper branch previously used clip + 0.01 * x, which
        # jumped to about 1.01 * clip just above the limit.
        u_in = jax.lax.cond(u_in < 0.0, lambda x: x * 0.01, lambda x: x, u_in)
        u_in = jax.lax.cond(u_in > self.clip,
                            lambda x: self.clip + (x - self.clip) * 0.01,
                            lambda x: x, u_in)
    else:
        u_in = jax.lax.clamp(0.0, u_in.astype(jnp.float32), self.clip).squeeze()

    # Book-keeping: dt never drops below DEFAULT_DT.
    new_dt = jnp.max(
        jnp.array([DEFAULT_DT, t - proper_time(controller_state.time)]))
    controller_state = controller_state.replace(
        time=t, steps=controller_state.steps + 1, dt=new_dt)
    return controller_state, u_in
def __call__(self, controller_state, obs):
    """Sample a discrete action from the policy for the current observation.

    The newest tracking error (scaled with ``self.p_scaler`` when
    ``self.normalize`` is set) is pushed onto the rolling ``errs`` buffer. A
    copy of the state is passed to ``self.derive_prob_and_value``. That copy
    has the error history given a leading axis of size 1 and carries a
    freshly split key. An action index is then drawn from
    ``exp(log_prob)``.

    Returns:
      A 4-tuple ``(controller_state, (u_in, 0), log_prob[u_in], value)``.
    """
    t = obs.time
    waveform = self.waveform
    errs = controller_state.errs
    pressure = obs.predicted_pressure
    target = waveform.at(t)

    if self.normalize:
        newest_err = (self.p_scaler(target).squeeze() -
                      self.p_scaler(pressure).squeeze())
    else:
        newest_err = target - pressure
    next_errs = jnp.roll(errs, shift=-1).at[-1].set(newest_err)

    sample_key, carry_key = jax.random.split(controller_state.key)

    # The policy network expects a leading batch axis on the error history.
    batched_state = controller_state.replace(
        errs=jnp.expand_dims(next_errs, axis=0), key=sample_key)

    # Not consumed yet; reserved for the decay-based override (TODO below).
    decay = self.decay(waveform, t)

    log_prob, value = self.derive_prob_and_value(batched_state)
    probs = jnp.exp(log_prob)
    # NOTE(review): sample_key is also handed to derive_prob_and_value via
    # batched_state; if that method draws randomness from state.key, the two
    # draws are correlated -- confirm.
    u_in = jax.random.choice(sample_key, probs.shape[0], p=probs)

    # TODO(@namanagarwal): emit `decay` instead of sampling when it is finite
    #   (mirroring the jnp.isinf(decay) cond used by the deep controllers).
    # TODO(@namanagarwal): decide whether u_in should be clamped to
    #   [0, self.clip].

    # Book-keeping: dt never drops below DEFAULT_DT.
    elapsed = t - proper_time(controller_state.time)
    controller_state = controller_state.replace(
        time=t,
        steps=controller_state.steps + 1,
        dt=jnp.max(jnp.array([DEFAULT_DT, elapsed])),
        key=carry_key,
        errs=next_errs,
    )
    return controller_state, (u_in, 0), log_prob[u_in], value
def test_proper_time(self):
    """proper_time maps +inf to 0 and returns finite times unchanged."""
    cases = ((float('inf'), 0), (10.0, 10.0))
    for value, expected in cases:
        self.assertEqual(proper_time(value), expected)