class BaseNetwork(nn.Module):
    """ Base class for our Neural networks """
    name = "BaseNetwork"

    def __init__(self, input_shape, output_shape):
        super(BaseNetwork, self).__init__()
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.progbar = Progbar(100)

    def forward(self, x):
        return self.model(x)

    def predict(self, x):
        x = U.torchify(x)
        if len(x.shape) == len(self.input_shape):
            x.unsqueeze_(0)
        return U.get(self.forward(x).squeeze())

    def optimize(self, l, clip=False):
        self.optimizer.zero_grad()
        l.backward()
        if clip:
            nn.utils.clip_grad_norm_(self.parameters(), 1.0)
        self.optimizer.step()

    def fit(self, X, Y, batch_size=50, epochs=1, clip=False):
        Xtmp, Ytmp = X.split(batch_size), Y.split(batch_size)
        for _ in range(epochs):
            self.progbar.__init__(len(Xtmp))
            for x, y in zip(Xtmp, Ytmp):
                #self.optimizer.zero_grad()
                loss = self.loss(self.forward(x), y)
                self.optimize(loss, clip)
                new_loss = self.loss(self.forward(x).detach(), y)
                self.progbar.add(1,
                                 values=[("old", U.get(loss)),
                                         ("new", U.get(new_loss))])

    def step(self, grad):
        self.optimizer.zero_grad()
        self.flaten.set_grad(grad)
        self.optimizer.step()

    def set_learning_rate(self, rate, verbose=0):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = rate
            if verbose:
                print("\n New Learning Rate ", param_group['lr'], "\n")

    def get_learning_rate(self):
        for param_group in self.optimizer.param_groups:
            return param_group['lr']

    def compile(self):
        self.loss = nn.SmoothL1Loss()
        self.optimizer = torch.optim.RMSprop(
            filter(lambda p: p.requires_grad, self.parameters()), lr=1e-4)
        self.flaten = Flattener(self.parameters)
        self.summary()

    def load(self, fname):
        print("Loading %s.%s" % (fname, self.name))
        dic = torch.load(U.CHECK_PATH + fname + "." + self.name)
        super(BaseNetwork, self).load_state_dict(dic)

    def save(self, fname):
        print("Saving to %s.%s" % (fname, self.name))
        dic = super(BaseNetwork, self).state_dict()
        torch.save(dic, "%s.%s" % (U.CHECK_PATH + fname, self.name))

    def copy(self, X):
        self.flaten.set(X.flaten.get())

    def summary(self):
        U.summary(self, self.input_shape)
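# Illustrative only: a minimal subclass sketch showing how BaseNetwork is
# intended to be used, assuming the repo's U helpers, Progbar and Flattener
# are importable as above. The class name "MLPValueNetwork" and the layer
# sizes are hypothetical and not part of the original code.
class MLPValueNetwork(BaseNetwork):
    name = "MLPValueNetwork"

    def __init__(self, input_shape, output_shape=(1,)):
        super(MLPValueNetwork, self).__init__(input_shape, output_shape)
        # self.model is what BaseNetwork.forward() delegates to
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(input_shape)), 64),
            nn.Tanh(),
            nn.Linear(64, int(np.prod(output_shape))),
        )
        # builds self.loss, self.optimizer and self.flaten, then prints a summary
        self.compile()

# Hypothetical usage:
#   vf = MLPValueNetwork(input_shape=(4,))
#   vf.fit(U.torchify(states), U.torchify(returns), batch_size=50, epochs=2)
#   value = vf.predict(state)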
class TRPO(BaseAgent):
    name = "TRPO"

    def __init__(self,
                 env,
                 policy_func,
                 value_func,
                 timesteps_per_batch=1000,  # what to train on
                 gamma=0.99,
                 lam=0.97,  # advantage estimation
                 max_kl=1e-2,
                 cg_iters=10,
                 entropy_coeff=0.0,
                 cg_damping=1e-2,
                 vf_iters=3,
                 max_train=1000,
                 checkpoint_freq=50):
        super(TRPO, self).__init__()
        self.env = env
        self.gamma = gamma
        self.lam = lam
        self.timesteps_per_batch = timesteps_per_batch
        self.max_kl = max_kl
        self.cg_iters = cg_iters
        self.entropy_coeff = entropy_coeff
        self.cg_damping = cg_damping
        self.vf_iters = vf_iters
        self.max_train = max_train
        self.checkpoint_freq = checkpoint_freq
        self.policy = policy_func(env)
        self.oldpolicy = policy_func(env)
        self.value_function = value_func(self.env)
        self.progbar = Progbar(self.timesteps_per_batch)
        self.path_generator = self.roller()
        self.value_function.summary()
        self.episodes_reward = collections.deque([], 50)
        self.episodes_len = collections.deque([], 50)
        self.done = 0
        self.functions = [self.policy, self.value_function]

    def act(self, state, train=True):
        if train:
            return self.policy.sample(state)
        return self.policy.act(state)

    def calculate_losses(self, states, actions, advantages, tdlamret):
        pi = self.policy(states)
        old_pi = self.oldpolicy(states).detach()
        kl_old_new = m_utils.kl_logits(old_pi, pi)
        entropy = m_utils.entropy_logits(pi)
        mean_kl = kl_old_new.mean()
        mean_entropy = entropy.mean()
        # advantage * pnew / pold
        ratio = torch.exp(
            m_utils.logp(pi, actions) - m_utils.logp(old_pi, actions))
        surrogate_gain = (ratio * advantages).mean()
        optimization_gain = surrogate_gain + self.entropy_coeff * mean_entropy
        losses = {
            "gain": optimization_gain,
            "meankl": mean_kl,
            "entropy": mean_entropy,
            "surrogate": surrogate_gain,
            "mean_entropy": mean_entropy
        }
        return losses

    def train(self):
        while self.done < self.max_train:
            print("=" * 40)
            print(" " * 15, self.done, "\n")
            self._train()
            if not self.done % self.checkpoint_freq:
                self.save()
            self.done = self.done + 1
        self.done = 0

    def _train(self):
        # Prepare for rollouts
        # ----------------------------------------
        self.oldpolicy.copy(self.policy)
        path = self.path_generator.__next__()
        states = U.torchify(path["state"])
        actions = U.torchify(path["action"]).long()
        advantages = U.torchify(path["advantage"])
        tdlamret = U.torchify(path["tdlamret"])
        vpred = U.torchify(path["vf"])  # predicted value function before update
        # standardized advantage function estimate
        advantages = (advantages - advantages.mean()) / advantages.std()

        losses = self.calculate_losses(states, actions, advantages, tdlamret)
        kl = losses["meankl"]
        optimization_gain = losses["gain"]
        loss_grad = self.policy.flaten.flatgrad(optimization_gain, retain=True)
        grad_kl = self.policy.flaten.flatgrad(kl, create=True, retain=True)
        theta_before = self.policy.flaten.get()
        self.log("Init param sum", theta_before.sum())
        self.log("explained variance",
                 (vpred - tdlamret).var() / tdlamret.var())

        if np.allclose(loss_grad.detach().cpu().numpy(), 0, atol=1e-15):
            print("Got zero gradient. not updating")
        else:
            print("Conjugate Gradient", end="")
            start = time.time()
            stepdir = m_utils.conjugate_gradient(self.Fvp(grad_kl),
                                                 loss_grad,
                                                 cg_iters=self.cg_iters)
            elapsed = time.time() - start
            print(", Done in %.3f" % elapsed)
            self.log("Conjugate Gradient in s", elapsed)
            assert stepdir.sum() != float("Inf")
            shs = .5 * stepdir.dot(self.Fvp(grad_kl)(stepdir))
            lm = torch.sqrt(shs / self.max_kl)
            self.log("lagrange multiplier:", lm)
            self.log("gnorm:", np.linalg.norm(loss_grad.cpu().detach().numpy()))
            fullstep = stepdir / lm
            expected_improve = loss_grad.dot(fullstep)
            surrogate_before = losses["surrogate"]
            stepsize = 1.0
            print("Line Search", end="")
            start = time.time()
            for _ in range(10):
                theta_new = theta_before + fullstep * stepsize
                self.policy.flaten.set(theta_new)
                losses = self.calculate_losses(states, actions, advantages,
                                               tdlamret)
                surr = losses["surrogate"]
                improve = surr - surrogate_before
                kl = losses["meankl"]
                if surr == float("Inf") or kl == float("Inf"):
                    print("Infinite value of losses")
                elif kl > self.max_kl:
                    print("Violated KL")
                elif improve < 0:
                    print("Surrogate didn't improve. shrinking step.")
                else:
                    print("Expected: %.3f Actual: %.3f" %
                          (expected_improve, improve))
                    print("Stepsize OK!")
                    self.log("Line Search", "OK")
                    break
                stepsize *= .5
            else:
                print("couldn't compute a good step")
                self.log("Line Search", "NOPE")
                self.policy.flaten.set(theta_before)
            elapsed = time.time() - start
            print(", Done in %.3f" % elapsed)
            self.log("Line Search in s", elapsed)
            self.log("KL", kl)
            self.log("Surrogate", surr)

        start = time.time()
        print("Value Function Update", end="")
        self.value_function.fit(states[::5],
                                tdlamret[::5],
                                batch_size=50,
                                epochs=self.vf_iters)
        elapsed = time.time() - start
        print(", Done in %.3f" % elapsed)
        self.log("Value Function Fitting in s", elapsed)
        self.log("TDlamret mean", tdlamret.mean())
        self.log("Last 50 rolls mean rew", np.mean(self.episodes_reward))
        self.log("Last 50 rolls mean len", np.mean(self.episodes_len))
        self.print()

    def roller(self):
        state = self.env.reset()
        ep_rews = 0
        ep_len = 0
        while True:
            path = {
                s: []
                for s in
                ["state", "action", "reward", "terminated", "vf", "next_vf"]
            }
            self.progbar.__init__(self.timesteps_per_batch)
            for _ in range(self.timesteps_per_batch):
                path["state"].append(state)
                # act
                action = self.act(state)
                vf = self.value_function.predict(state)
                state, rew, done, _ = self.env.step(action)
                path["action"].append(action)
                path["reward"].append(rew)
                path["vf"].append(vf)
                path["terminated"].append(done * 1.0)
                path["next_vf"].append((1 - done) * vf)
                ep_rews += rew
                ep_len += 1
                if done:
                    state = self.env.reset()
                    self.episodes_reward.append(ep_rews)
                    self.episodes_len.append(ep_len)
                    ep_rews = 0
                    ep_len = 0
                self.progbar.add(1)
            for k, v in path.items():
                path[k] = np.array(v)
            self.add_vtarg_and_adv(path)
            yield path

    def Fvp(self, grad_kl):
        def fisher_product(v):
            kl_v = (grad_kl * v).sum()
            grad_grad_kl = self.policy.flaten.flatgrad(kl_v, retain=True)
            return grad_grad_kl + v * self.cg_damping

        return fisher_product

    def add_vtarg_and_adv(self, path):
        # General Advantage Estimation
        terminal = np.append(path["terminated"], 0)
        vpred = np.append(path["vf"], path["next_vf"])
        T = len(path["reward"])
        path["advantage"] = np.empty(T, 'float32')
        lastgaelam = 0
        for t in reversed(range(T)):
            nonterminal = 1 - terminal[t + 1]
            delta = path["reward"][t] + \
                self.gamma * vpred[t + 1] * nonterminal - vpred[t]
            path["advantage"][t] = lastgaelam = \
                delta + self.gamma * self.lam * nonterminal * lastgaelam
        path["tdlamret"] = (path["advantage"] + path["vf"]).reshape(-1, 1)
class DDQN(BaseAgent):
    """ Double Deep Q Networks """
    name = "DDQN"

    def __init__(self,
                 env,
                 deep_func,
                 gamma,
                 batch_size,
                 memory_min,
                 memory_max,
                 update_double=10000,
                 train_steps=1000000,
                 log_freq=1000,
                 eps_start=1,
                 eps_decay=-1,
                 eps_min=0.1):
        super(DDQN, self).__init__()
        self.env = env
        self.Q = self.model = deep_func(env)
        self.target_Q = deep_func(env)
        self.discount = gamma
        self.memory_min = memory_min
        self.memory_max = memory_max
        self.eps = eps_start
        self.train_steps = train_steps
        self.batch_size = batch_size
        self.done = 0
        self.log_freq = log_freq
        self.progbar = Progbar(self.memory_max)
        self.memory = ReplayMemory(
            self.memory_max,
            ["state", "action", "reward", "next_state", "terminated"])
        self.eps_decay = eps_decay
        if eps_decay == -1:
            self.eps_decay = 1 / train_steps
        self.eps_min = eps_min
        self.update_double = update_double
        self.actions = []
        self.path_generator = self.roller()
        self.past_rewards = collections.deque([], 50)
        self.functions = [self.Q]

    def act(self, state, train=True):
        if train:
            if np.random.rand() < self.eps:
                return np.random.randint(self.env.action_space.n)
        return np.argmax(self.Q.predict(state))

    def train(self):
        self.progbar.__init__(self.memory_min)
        while self.memory.size < self.memory_min:
            self.path_generator.__next__()
        while self.done < self.train_steps:
            to_log = 0
            self.progbar.__init__(self.update_double)
            old_theta = self.Q.flaten.get()
            self.target_Q.copy(self.Q)
            while to_log < self.update_double:
                self.path_generator.__next__()
                rollout = self.memory.sample(self.batch_size)
                state_batch = U.torchify(rollout["state"])
                action_batch = U.torchify(rollout["action"]).long()
                reward_batch = U.torchify(rollout["reward"])
                non_final_batch = U.torchify(1 - rollout["terminated"])
                next_state_batch = U.torchify(rollout["next_state"])
                #current_q = self.Q(state_batch)
                current_q = self.Q(state_batch).gather(
                    1, action_batch.unsqueeze(1)).view(-1)
                _, a_prime = self.Q(next_state_batch).max(1)
                # Compute the target of the current Q values
                next_max_q = self.target_Q(next_state_batch).gather(
                    1, a_prime.unsqueeze(1)).view(-1)
                target_q = reward_batch + \
                    self.discount * non_final_batch * next_max_q.squeeze()
                # Compute loss
                loss = self.Q.loss(current_q, target_q.detach())
                # loss = self.Q.total_loss(current_q, target_q)
                # Optimize the model
                self.Q.optimize(loss, clip=True)
                self.progbar.add(self.batch_size,
                                 values=[("Loss", U.get(loss))])
                to_log += self.batch_size
            self.target_Q.copy(self.Q)
            new_theta = self.Q.flaten.get()
            self.log("Delta Theta L1",
                     U.get((new_theta - old_theta).abs().mean()))
            self.log("Av 50ep rew", np.mean(self.past_rewards))
            self.log("Max 50ep rew", np.max(self.past_rewards))
            self.log("Min 50ep rew", np.min(self.past_rewards))
            self.log("Epsilon", self.eps)
            self.log("Done", self.done)
            self.log("Total", self.train_steps)
            self.target_Q.copy(self.Q)
            self.print()
            #self.play()
            self.save()

    def set_eps(self, x):
        self.eps = max(x, self.eps_min)

    def roller(self):
        state = self.env.reset()
        ep_reward = 0
        while True:
            episode = self.memory.empty_episode()
            for i in range(self.batch_size):
                # save current state
                episode["state"].append(state)
                # act
                action = self.act(state)
                self.actions.append(action)
                state, rew, done, info = self.env.step(action)
                episode["next_state"].append(state)
                episode["action"].append(action)
                episode["reward"].append(rew)
                episode["terminated"].append(done)
                ep_reward += rew
                self.set_eps(self.eps - self.eps_decay)
                if done:
                    self.past_rewards.append(ep_reward)
                    state = self.env.reset()
                    ep_reward = 0
                self.done += 1
                if not (self.done) % self.update_double:
                    self.update = True
            # record the episodes
            self.memory.record(episode)
            if self.memory.size < self.memory_min:
                self.progbar.add(self.batch_size, values=[("Loss", 0.0)])
            yield True
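# Standalone sketch of the double-DQN target computed in DDQN.train above:
# the online network chooses a' = argmax_a Q(s', a), the target network
# evaluates it. Toy linear "networks" and random data only; nothing here
# comes from the repo except the formula.
import torch
import torch.nn as nn

n_actions, batch = 3, 4
online_q = nn.Linear(5, n_actions)   # stands in for self.Q
target_q = nn.Linear(5, n_actions)   # stands in for self.target_Q

next_states = torch.randn(batch, 5)
rewards = torch.randn(batch)
non_final = torch.ones(batch)        # 1 - terminated
discount = 0.99

with torch.no_grad():
    # action selection with the online network ...
    a_prime = online_q(next_states).argmax(dim=1)
    # ... value estimation with the target network
    next_max_q = target_q(next_states).gather(1, a_prime.unsqueeze(1)).squeeze(1)
    target = rewards + discount * non_final * next_max_q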
class GateTRPO(BaseAgent):
    name = "GateTRPO"

    def __init__(self,
                 env,
                 gatepolicy,
                 policy_func,
                 value_func,
                 n_options,
                 option_len=3,
                 timesteps_per_batch=1000,
                 gamma=0.99,
                 lam=0.97,
                 MI_lambda=1e-3,
                 gate_max_kl=1e-2,
                 option_max_kl=1e-2,
                 cg_iters=10,
                 cg_damping=1e-2,
                 vf_iters=2,
                 max_train=1000,
                 ls_step=0.5,
                 checkpoint_freq=50):
        super(GateTRPO, self).__init__(name=env.name)
        self.n_options = n_options
        self.env = env
        self.gamma = gamma
        self.lam = lam
        self.MI_lambda = MI_lambda
        self.current_option = 0
        self.timesteps_per_batch = timesteps_per_batch
        self.gate_max_kl = gate_max_kl
        self.cg_iters = cg_iters
        self.cg_damping = cg_damping
        self.max_train = max_train
        self.ls_step = ls_step
        self.checkpoint_freq = checkpoint_freq
        self.policy = gatepolicy(env.observation_space.shape,
                                 env.action_space.n)
        self.oldpolicy = gatepolicy(env.observation_space.shape,
                                    env.action_space.n,
                                    verbose=0)
        self.oldpolicy.disable_grad()
        self.progbar = Progbar(self.timesteps_per_batch)
        self.path_generator = self.roller()
        self.episodes_reward = collections.deque([], 5)
        self.episodes_len = collections.deque([], 5)
        self.done = 0
        self.functions = [self.policy]
        self.options = [
            OptionTRPO(env.name, i, env, policy_func, value_func, gamma, lam,
                       option_len, option_max_kl, cg_iters, cg_damping,
                       vf_iters, ls_step, self.logger, checkpoint_freq)
            for i in range(n_options)
        ]

    def act(self, state, train=True):
        if train:
            return self.policy.sample(state)
        return self.policy.act(state)

    def calculate_losses(self, states, options, actions, advantages):
        RIM = self.KLRIM(states, options, actions)
        old_pi = RIM["old_log_pi_oia_s"]
        pi = RIM["old_log_pi_oia_s"]
        # advantage * pnew / pold
        ratio = torch.exp(
            m_utils.logp(pi, actions) - m_utils.logp(old_pi, actions))
        surrogate_gain = (ratio * advantages).mean()
        optimization_gain = surrogate_gain - self.MI_lambda * RIM["MI"]

        def surr_get(grad=False):
            Id, pid = RIM["MI_get"](grad)
            return (torch.exp(
                m_utils.logp(pid, actions) - m_utils.logp(old_pi, actions)) *
                    advantages).mean() - self.MI_lambda * Id

        RIM["gain"] = optimization_gain
        RIM["surr_get"] = surr_get
        return RIM

    def train(self):
        while self.done < self.max_train:
            print("=" * 40)
            print(" " * 15, self.done, "\n")
            self.logger.step()
            path = self.path_generator.__next__()
            self.oldpolicy.copy(self.policy)
            for p in self.options:
                p.oldpolicy.copy(p.policy)
            self._train(path)
            self.logger.display()
            if not self.done % self.checkpoint_freq:
                self.save()
                for p in self.options:
                    p.save()
            self.done = self.done + 1
        self.done = 0

    def _train(self, path):
        states = U.torchify(path["states"])
        options = U.torchify(path["options"]).long()
        actions = U.torchify(path["actions"]).long()
        advantages = U.torchify(path["baseline"])
        tdlamret = U.torchify(path["tdlamret"])
        vpred = U.torchify(path["vf"])  # predicted value function before update
        #advantages = (advantages - advantages.mean()) / advantages.std()  # standardized advantage function estimate

        losses = self.calculate_losses(states, options, actions, advantages)
        kl = losses["gate_meankl"]
        optimization_gain = losses["gain"]
        loss_grad = self.policy.flaten.flatgrad(optimization_gain, retain=True)
        grad_kl = self.policy.flaten.flatgrad(kl, create=True, retain=True)
        theta_before = self.policy.flaten.get()
        self.log("Init param sum", theta_before.sum())
        self.log("explained variance",
                 (vpred - tdlamret).var() / tdlamret.var())

        if np.allclose(loss_grad.detach().cpu().numpy(), 0, atol=1e-19):
            print("Got zero gradient. not updating")
        else:
            with C.timeit("Conjugate Gradient"):
                stepdir = m_utils.conjugate_gradient(self.Fvp(grad_kl),
                                                     loss_grad,
                                                     cg_iters=self.cg_iters)
            self.log("Conjugate Gradient in s", C.elapsed)
            assert stepdir.sum() != float("Inf")
            shs = .5 * stepdir.dot(self.Fvp(grad_kl)(stepdir))
            lm = torch.sqrt(shs / self.gate_max_kl)
            self.log("lagrange multiplier:", lm)
            self.log("gnorm:", np.linalg.norm(loss_grad.cpu().detach().numpy()))
            fullstep = stepdir / lm
            expected_improve = loss_grad.dot(fullstep)
            surrogate_before = losses["gain"].detach()
            with C.timeit("Line Search"):
                stepsize = 1.0
                for i in range(10):
                    theta_new = theta_before + fullstep * stepsize
                    self.policy.flaten.set(theta_new)
                    surr = losses["surr_get"]()
                    improve = surr - surrogate_before
                    kl = losses["KL_gate_get"]()
                    if surr == float("Inf") or kl == float("Inf"):
                        C.warning("Infinite value of losses")
                    elif kl > self.gate_max_kl:
                        C.warning("Violated KL")
                    elif improve < 0:
                        stepsize *= self.ls_step
                    else:
                        self.log("Line Search", "OK")
                        break
                else:
                    improve = 0
                    self.log("Line Search", "NOPE")
                    self.policy.flaten.set(theta_before)
            for op in self.options:
                losses["gain"] = losses["surr_get"](grad=True)
                op.train(states, options, actions, advantages, tdlamret,
                         losses)
            surr = losses["surr_get"]()
            improve = surr - surrogate_before
            self.log("Expected", expected_improve)
            self.log("Actual", improve)
            self.log("Line Search in s", C.elapsed)
            self.log("LS Steps", i)
            self.log("KL", kl)
            self.log("MI", -losses["MI"])
            self.log("MI improve", -losses["MI_get"]()[0] + losses["MI"])
            self.log("Surrogate", surr)
            self.log("Gate KL", losses["KL_gate_get"]())
            self.log("HRL KL", losses["KL_get"]())
            self.log("TDlamret mean", tdlamret.mean())
            del (improve, surr, kl)
        self.log("Last %i rolls mean rew" % len(self.episodes_reward),
                 np.mean(self.episodes_reward))
        self.log("Last %i rolls mean len" % len(self.episodes_len),
                 np.mean(self.episodes_len))
        del (losses, states, options, actions, advantages, tdlamret, vpred,
             optimization_gain, loss_grad, grad_kl)
        for _ in range(10):
            gc.collect()

    def roller(self):
        state = self.env.reset()
        path = {
            "states":
            np.array([state for _ in range(self.timesteps_per_batch)]),
            "options": np.zeros(self.timesteps_per_batch).astype(int),
            "actions": np.zeros(self.timesteps_per_batch).astype(int),
            "rewards": np.zeros(self.timesteps_per_batch),
            "terminated": np.zeros(self.timesteps_per_batch),
            "vf": np.zeros(self.timesteps_per_batch)
        }
        self.current_option = self.act(state)
        self.options[self.current_option].select()
        ep_rews = 0
        ep_len = 0
        t = 0
        done = True
        rew = 0.0
        self.progbar.__init__(self.timesteps_per_batch)
        while True:
            if self.options[self.current_option].finished:
                self.current_option = self.act(state)
            action = self.options[self.current_option].act(state)
            vf = self.options[self.current_option].value_function.predict(state)
            if t > self.timesteps_per_batch - 1:
                path["next_vf"] = vf * (1 - done * 1.0)
                self.add_vtarg_and_adv(path)
                yield path
                t = 0
                self.progbar.__init__(self.timesteps_per_batch)
            path["states"][t] = state
            state, rew, done, _ = self.env.step(action)
            path["options"][t] = self.options[self.current_option].option_n
            path["actions"][t] = action
            path["rewards"][t] = rew
            path["vf"][t] = vf
            path["terminated"][t] = done * 1.0
            ep_rews += rew
            ep_len += 1
            t += 1
            self.progbar.add(1)
            if done:
                state = self.env.reset()
                self.episodes_reward.append(ep_rews)
                self.episodes_len.append(ep_len)
                ep_rews = 0
                ep_len = 0

    def add_vtarg_and_adv(self, path):
        # General Advantage Estimation
        terminal = np.append(path["terminated"], 0)
        vpred = np.append(path["vf"], path["next_vf"])
        T = len(path["rewards"])
        path["advantage"] = np.empty(T, 'float32')
        lastgaelam = 0
        for t in reversed(range(T)):
            nonterminal = 1 - terminal[t + 1]
            delta = path["rewards"][t] + \
                self.gamma * vpred[t + 1] * nonterminal - vpred[t]
            path["advantage"][t] = lastgaelam = \
                delta + self.gamma * self.lam * nonterminal * lastgaelam
        path["tdlamret"] = (path["advantage"] + path["vf"]).reshape(-1, 1)
        path["baseline"] = (path["advantage"] - np.mean(path["advantage"])) / \
            np.std(path["advantage"])

    def Fvp(self, grad_kl):
        def fisher_product(v):
            kl_v = (grad_kl * v).sum()
            grad_grad_kl = self.policy.flaten.flatgrad(kl_v, retain=True)
            return grad_grad_kl + v * self.cg_damping

        return fisher_product

    def KLRIM(self, states, options, actions):
        r"""
        pg      : \pi_g
        pi_a_so : \pi(a|s,o)
        pi_oa_s : \pi(o,a|s)
        pi_o_as : \pi(o|a,s)
        pi_a_s  : \pi(a|s)
        old     : \tilde{\pi}
        """
        old_log_pi_a_so = torch.cat([
            p.oldpolicy.logsoftmax(states).unsqueeze(1).detach()
            for p in self.options
        ], dim=1)
        old_log_pg_o_s = self.oldpolicy.logsoftmax(states).detach()
        old_log_pi_oa_s = old_log_pi_a_so + old_log_pg_o_s.unsqueeze(-1)
        old_log_pi_a_s = old_log_pi_oa_s.exp().sum(1).log()
        old_log_pi_oia_s = old_log_pi_oa_s[np.arange(states.shape[0]), options]

    def calculate_surr(self, states, options, actions, advantages, grad=False):
        if grad:
            log_pi_a_so = torch.cat([
                p.policy.logsoftmax(states).unsqueeze(1) for p in self.options
            ], dim=1)
            log_pg_o_s = self.policy.logsoftmax(states)
        else:
            with torch.set_grad_enabled(False):
                log_pi_a_so = torch.cat([
                    p.policy.logsoftmax(states).unsqueeze(1)
                    for p in self.options
                ], dim=1)
                log_pg_o_s = self.policy.logsoftmax(states)
        log_pi_oa_s = log_pi_a_so + log_pg_o_s.unsqueeze(-1)
        log_pi_a_s = log_pi_oa_s.exp().sum(1).log()
        log_pi_o_as = log_pi_oa_s - log_pi_a_s.unsqueeze(1)
        H_O_AS = -(log_pi_a_s.exp() *
                   (log_pi_o_as * log_pi_o_as.exp()).sum(1)).sum(-1).mean()
        H_O = m_utils.entropy_logits(log_pg_o_s).mean()
        log_pi_o_ais = log_pi_o_as[np.arange(states.shape[0]), :,
                                   actions].exp().mean(0).log()
        log_pi_oi_ais = log_pi_o_as[np.arange(states.shape[0]), options,
                                    actions]
        log_pi_oia_s = log_pi_oa_s[np.arange(states.shape[0]), options]
        MI = m_utils.entropy_logits(log_pi_o_ais) - \
            m_utils.entropy_logits(log_pi_oi_ais)
        ratio = torch.exp(
            m_utils.logp(pi, actions) - m_utils.logp(old_pi, actions))
        surrogate_gain = (ratio * advantages).mean()
        optimization_gain = surrogate_gain - self.MI_lambda * MI
        # def surr_get(self, grad=False):
        #     Id, pid = RIM["MI_get"](grad)
        #     return (torch.exp(m_utils.logp(pid, actions) - m_utils.logp(old_pi, actions)) * advantages).mean() - self.MI_lambda * Id
        #
        # RIM["gain"] = optimization_gain
        # RIM["surr_get"] = surr_get
        # return RIM
        #
        # return MI
        #
        # log_pi_a_so = torch.cat([p.policy.logsoftmax(states).unsqueeze(1) for p in self.options], dim=1)
        # log_pg_o_s = self.policy.logsoftmax(states)
        # log_pi_oa_s = log_pi_a_so + log_pg_o_s.unsqueeze(-1)
        # log_pi_a_s = log_pi_oa_s.exp().sum(1).log()
        # log_pi_o_as = log_pi_oa_s - log_pi_a_s.unsqueeze(1)
        #
        # log_pi_o_ais = log_pi_o_as[np.arange(states.shape[0]), :, actions].exp().mean(0).log()
        # log_pi_oi_ais = log_pi_o_as[np.arange(states.shape[0]), options, actions]

    def mean_HKL(self, states, old_log_pi_a_s, grad=False):
        if grad:
            log_pi_a_so = torch.cat([
                p.policy.logsoftmax(states).unsqueeze(1) for p in self.options
            ], dim=1)
            log_pg_o_s = self.policy.logsoftmax(states)
        else:
            log_pi_a_so = torch.cat([
                p.policy.logsoftmax(states).detach().unsqueeze(1)
                for p in self.options
            ], dim=1)
            log_pg_o_s = self.policy.logsoftmax(states).detach()
        log_pi_a_s = (log_pi_a_so + log_pg_o_s.unsqueeze(-1)).exp().sum(1).log()
        mean_kl_new_old = m_utils.kl_logits(old_log_pi_a_s, log_pi_a_s).mean()
        return mean_kl_new_old

    def mean_KL_gate(self, states, old_log_pg_o_s, grad=False):
        if grad:
            log_pg_o_s = self.policy.logsoftmax(states)
        else:
            log_pg_o_s = self.policy.logsoftmax(states).detach()
        return m_utils.kl_logits(old_log_pg_o_s, log_pg_o_s).mean()

    def load(self):
        super(GateTRPO, self).load()
        for p in self.options:
            p.load()
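# Standalone sketch of the conjugate-gradient pattern behind
# m_utils.conjugate_gradient(self.Fvp(grad_kl), loss_grad, cg_iters=...):
# CG only ever needs a matrix-vector product, which is why Fvp returns a
# closure (a Fisher-vector product) instead of forming the Fisher matrix.
# Generic numpy code; this is not the repo's m_utils implementation.
import numpy as np

def conjugate_gradient(f_Ax, b, cg_iters=10, tol=1e-10):
    x = np.zeros_like(b)
    r = b.copy()          # residual b - A x (x = 0 initially)
    p = b.copy()          # search direction
    rr = r.dot(r)
    for _ in range(cg_iters):
        Ap = f_Ax(p)
        alpha = rr / p.dot(Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rr = r.dot(r)
        if new_rr < tol:
            break
        p = r + (new_rr / rr) * p
        rr = new_rr
    return x

# Usage with an explicit SPD matrix standing in for the damped Fisher:
A = np.array([[2.0, 0.3], [0.3, 1.0]])
b = np.array([1.0, -1.0])
x = conjugate_gradient(lambda v: A.dot(v), b)
# A.dot(x) should be close to b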