import os
import random

import numpy as np
import tensorflow as tf

# Project-local dependencies (Actor, Critic, Discriminator, VAE_Encoder, HorizonMemory,
# ReplayMemory, tf_gaussian_KL, STATE_SHAPE, ACTION_SHAPE, ACTION_NUM, BETA_INIT) are
# assumed to be defined in the repository's other modules.


class GAIL:
    def __init__(self, vail_sample, reward_shift, reward_aug, gae_norm, global_norm,
                 actor_lr, critic_lr, disc_lr, actor_units, critic_units, disc_units,
                 disc_reduce_units, gamma, lambd, clip, entropy, epochs, batch_size,
                 update_rate, data_dir, demo_list):
        # build networks
        self.actor = Actor(lr=actor_lr, hidden_units=actor_units)
        self.critic = Critic(lr=critic_lr, hidden_units=critic_units)
        self.discriminator = Discriminator(lr=disc_lr, hidden_units=disc_units,
                                           reduce_units=disc_reduce_units)
        self.encoder = VAE_Encoder(latent_num=64)
        # set hyperparameters
        self.vail_sample = vail_sample
        self.reward_shift = reward_shift
        self.reward_aug = reward_aug
        self.gae_norm = gae_norm
        self.gamma = gamma
        self.lambd = lambd
        self.gam_lam = gamma * lambd
        self.clip = clip
        self.entropy = entropy
        self.epochs = epochs
        self.batch_size = batch_size
        self.half_batch_size = batch_size // 2
        self.update_rate = update_rate
        self.grad_global_norm = global_norm
        self.beta = BETA_INIT
        # build memory
        self.memory = HorizonMemory(use_reward=reward_aug)
        self.replay = ReplayMemory()
        # build expert demonstration pipeline
        self.data_dir = data_dir
        self.demo_list = os.listdir(data_dir)
        self.demo_group_num = 500
        self.demo_rotate = 5
        assert len(demo_list) >= self.demo_group_num
        self.set_demo()
        # ready
        self.dummy_forward()
        self.actor_vars = self.actor.trainable_variables + self.encoder.trainable_variables
        self.critic_vars = self.critic.trainable_variables + self.encoder.trainable_variables
        self.disc_vars = self.discriminator.trainable_variables + self.encoder.trainable_variables

    def dummy_forward(self):
        # run a dummy batch through every network once so all weights are built and connected
        dummy_state = np.zeros([1] + STATE_SHAPE, dtype=np.float32)
        dummy_action = np.zeros([1] + ACTION_SHAPE, dtype=np.float32)
        self.encoder(dummy_state)
        self.actor(self.encoder, dummy_state)
        self.critic(self.encoder, dummy_state)
        self.discriminator(self.encoder, dummy_state, dummy_action)

    def set_demo(self):
        # re-sample a group of expert demonstrations from the data directory
        self.demo_list = os.listdir(self.data_dir)
        selected_demos = random.sample(self.demo_list, self.demo_group_num)
        expert_states = []
        expert_actions = []
        for demo_name in selected_demos:
            demo = np.load(self.data_dir + demo_name)
            expert_states.append(demo['state'])
            expert_actions.append(demo['action'])
        self.expert_states = np.concatenate(expert_states, axis=0)
        self.expert_actions = np.concatenate(expert_actions, axis=0)
        del demo

    def get_demonstration(self, sample_num):
        # refresh the expert pool if it cannot provide enough samples
        if len(self.expert_states) < sample_num:
            self.set_demo()
        index = np.arange(len(self.expert_states))
        np.random.shuffle(index)
        index = index[:sample_num]
        return self.expert_states[index], self.expert_actions[index]

    def memory_process(self, next_state, done):
        # rollout structure:
        # [[(1,64,64,3)], [], ...], [[(1,2),(1,9),(1,3),(1,4)], [], ...], [[c_pi, d_pi, s_pi, a_pi], [], ...]
        if self.reward_aug:
            states, actions, log_old_pis, rewards = self.memory.rollout()
        else:
            states, actions, log_old_pis = self.memory.rollout()
        np_states = np.concatenate(states + [next_state], axis=0)
        np_actions = np.concatenate(actions, axis=0)
        np_rewards = self.get_reward(np_states[:-1], np_actions)        # (N, 1)
        if self.reward_aug:
            np_env_rewards = np.stack(rewards, axis=0).reshape(-1, 1)
            np_rewards = np_rewards + np_env_rewards
        gae, oracle = self.get_gae_oracle(np_states, np_rewards, done)  # (N, 1), (N, 1)
        self.replay.append(states, actions, log_old_pis, gae, oracle)
        self.memory.flush()
        if len(self.replay) >= self.update_rate:
            self.update()
            self.replay.flush()

    def get_action(self, state):
        policy = self.actor(self.encoder, state).numpy()[0]
        action = np.random.choice(ACTION_NUM, p=policy)
        # action = np.argmax(policy)
        action_one_hot = np.eye(ACTION_NUM, dtype=np.float32)[[action]]  # (1, ACTION_NUM)
        log_old_pi = [[np.log(policy[action] + 1e-8)]]                   # (1, 1)
        return action, action_one_hot, log_old_pi, policy

    def get_reward(self, states, actions):
        # imitation reward derived from the discriminator output d in (0, 1)
        d = self.discriminator(self.encoder, states, actions).numpy()   # (N, 1)
        # rewards = 0.5 - d             # linear reward
        # rewards = np.tan(0.5 - d)     # tan reward
        if self.reward_shift:
            rewards = -np.log(2.0 * d + 1e-8)   # log reward shifted to zero at d = 0.5
        else:
            rewards = -np.log(d + 1e-8)         # log reward
        # rewards = 0.1 * np.where(rewards > 1, 1, rewards)
        return rewards

    def get_gae_oracle(self, states, rewards, done):
        # `states` includes the next state, so values has N+1 entries
        values = self.critic(self.encoder, states).numpy()   # (N+1, 1)
        if done:
            values[-1] = np.float32([0])
        N = len(rewards)
        gae = 0
        gaes = np.zeros((N, 1), dtype=np.float32)
        oracles = np.zeros((N, 1), dtype=np.float32)
        for t in reversed(range(N)):
            oracles[t] = rewards[t] + self.gamma * values[t + 1]
            delta = oracles[t] - values[t]
            gae = delta + self.gam_lam * gae
            gaes[t][0] = gae
        # oracles = gaes + values[:-1]   # (N, 1)
        if self.gae_norm:
            gaes = (gaes - np.mean(gaes)) / (np.std(gaes) + 1e-8)
        return gaes, oracles

    def update(self):
        # load & flatten collected data
        states, actions, log_old_pis, gaes, oracles = self.replay.rollout()
        states = np.concatenate(states, axis=0)
        actions = np.concatenate(actions, axis=0)
        log_old_pis = np.concatenate(log_old_pis, axis=0)
        gaes = np.concatenate(gaes, axis=0)
        oracles = np.concatenate(oracles, axis=0)
        N = len(states)

        # update discriminator on half-agent / half-expert batches
        s_e, a_e = self.get_demonstration(N)
        batch_num = N // self.half_batch_size
        index = np.arange(N)
        np.random.shuffle(index)
        for i in range(batch_num):
            idx = index[i * self.half_batch_size:(i + 1) * self.half_batch_size]
            s_concat = np.concatenate([states[idx], s_e[idx]], axis=0)
            a_concat = np.concatenate([actions[idx], a_e[idx]], axis=0)
            with tf.GradientTape(persistent=True) as tape:
                mu, std, sampled = self.discriminator.encode(self.encoder, s_concat, a_concat)
                discs = self.discriminator.decode(sampled if self.vail_sample else mu)
                kld_loss = tf.reduce_mean(tf_gaussian_KL(mu, 0, std, 1))
                agent_loss = -tf.reduce_mean(
                    tf.math.log(discs[:self.half_batch_size] + 1e-8))
                expert_loss = -tf.reduce_mean(
                    tf.math.log(1 + 1e-8 - discs[self.half_batch_size:]))
                disc_loss = agent_loss + expert_loss
                discriminator_loss = disc_loss + self.beta * kld_loss
            disc_grads = tape.gradient(discriminator_loss, self.disc_vars)
            if self.grad_global_norm > 0:
                disc_grads, _ = tf.clip_by_global_norm(disc_grads, self.grad_global_norm)
            self.discriminator.opt.apply_gradients(zip(disc_grads, self.disc_vars))
            del tape
        # TODO: update posterior
        # L1 loss = logQ(code|s, prev_a, prev_code)

        # update actor & critic
        # batch_num = math.ceil(len(states) / self.batch_size)
        batch_num = len(gaes) // self.batch_size
        index = np.arange(len(gaes))
        for _ in range(self.epochs):
            np.random.shuffle(index)
            for i in range(batch_num):
                # if i == batch_num - 1:
                #     idx = index[i * self.batch_size:]
                # else:
                idx = index[i * self.batch_size:(i + 1) * self.batch_size]
                state = states[idx]
                action = actions[idx]
                log_old_pi = log_old_pis[idx]
                gae = gaes[idx]
                oracle = oracles[idx]
                # update critic
                with tf.GradientTape(persistent=True) as tape:
                    values = self.critic(self.encoder, state)             # (N, 1)
                    critic_loss = tf.reduce_mean((oracle - values) ** 2)  # MSE loss
                critic_grads = tape.gradient(critic_loss, self.critic_vars)
                if self.grad_global_norm > 0:
                    critic_grads, _ = tf.clip_by_global_norm(critic_grads, self.grad_global_norm)
                self.critic.opt.apply_gradients(zip(critic_grads, self.critic_vars))
                del tape
                # update actor
                with tf.GradientTape(persistent=True) as tape:
                    pred_action = self.actor(self.encoder, state)
                    # RL (PPO) term
                    log_pi = tf.expand_dims(
                        tf.math.log(tf.reduce_sum(pred_action * action, axis=1) + 1e-8),
                        axis=1)                                            # (N, 1)
                    ratio = tf.exp(log_pi - log_old_pi)
                    clip_ratio = tf.clip_by_value(ratio, 1 - self.clip, 1 + self.clip)
                    clip_loss = -tf.reduce_mean(tf.minimum(ratio * gae, clip_ratio * gae))
                    entropy = tf.reduce_mean(tf.exp(log_pi) * log_pi)
                    actor_loss = clip_loss + self.entropy * entropy
                actor_grads = tape.gradient(actor_loss, self.actor_vars)   # NOTE: freeze posterior
                if self.grad_global_norm > 0:
                    actor_grads, _ = tf.clip_by_global_norm(actor_grads, self.grad_global_norm)
                self.actor.opt.apply_gradients(zip(actor_grads, self.actor_vars))
                del tape
        # print('%d samples trained... D loss: %.4f C loss: %.4f A loss: %.4f\t\t\t'
        #       % (len(gaes), disc_loss, critic_loss, actor_loss), end='\r')

    def save_model(self, dir, tag=''):
        self.actor.save_weights(dir + tag + 'actor.h5')
        self.critic.save_weights(dir + tag + 'critic.h5')
        self.discriminator.save_weights(dir + tag + 'discriminator.h5')
        self.encoder.save_weights(dir + tag + 'encoder.h5')

    def load_model(self, dir, tag=''):
        if os.path.exists(dir + tag + 'actor.h5'):
            self.actor.load_weights(dir + tag + 'actor.h5')
            print('Actor loaded... %s%sactor.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'critic.h5'):
            self.critic.load_weights(dir + tag + 'critic.h5')
            print('Critic loaded... %s%scritic.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'discriminator.h5'):
            self.discriminator.load_weights(dir + tag + 'discriminator.h5')
            print('Discriminator loaded... %s%sdiscriminator.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'encoder.h5'):
            self.encoder.load_weights(dir + tag + 'encoder.h5')
            print('Encoder loaded... %s%sencoder.h5' % (dir, tag))

    def load_encoder(self, dir, tag=''):
        if os.path.exists(dir + tag + 'encoder.h5'):
            self.encoder.load_weights(dir + tag + 'encoder.h5')
            print('Encoder loaded... %s%sencoder.h5' % (dir, tag))
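# --- illustration only (not part of the original repository) -------------------------------
# A minimal, self-contained sketch of the recursion used in GAIL.get_gae_oracle above:
# `oracles` are one-step TD targets r_t + gamma * V(s_{t+1}), and `gaes` accumulate the TD
# errors with decay gamma * lambda (Generalized Advantage Estimation). The function name and
# the reward/value numbers in the example call are made up purely for illustration.
def gae_reference(rewards, values, gamma=0.99, lambd=0.95):
    """rewards: (N,) array; values: (N+1,) array including the bootstrap value."""
    N = len(rewards)
    gaes = np.zeros(N, dtype=np.float32)
    oracles = np.zeros(N, dtype=np.float32)
    gae = 0.0
    for t in reversed(range(N)):
        oracles[t] = rewards[t] + gamma * values[t + 1]   # TD target
        delta = oracles[t] - values[t]                    # TD error
        gae = delta + gamma * lambd * gae                 # discounted sum of TD errors
        gaes[t] = gae
    return gaes, oracles

# example: a 3-step rollout with a bootstrap value of 0.5 for the final state
# gae_reference(np.array([1.0, 0.0, 1.0]), np.array([0.9, 0.8, 0.7, 0.5]))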
import numpy as np
import torch
import torch.optim as optim

# Feedforward, ReplayMemory and the module-level TensorBoard `writer`
# (a torch.utils.tensorboard.SummaryWriter) are assumed to be defined elsewhere in this module.


class Agent():
    def __init__(self, algo, optimizer, env, num_actions, memory_size=10000):
        self.algo = algo  # currently not used; future work to make learning algorithms modular
        self.env = env
        self.policy = Feedforward(env.observation_space.shape[0], env.action_space.n)
        # Reusing the ReplayMemory class from Q-learning, which may be a bit confusing,
        # so let me clarify (I'm just being lazy here): policy gradient is an on-policy
        # method, so experiences generated by a different policy cannot be used. The memory
        # is flushed at the beginning of every epoch (see the train function), so the policy
        # is updated only on trajectories collected by the current policy.
        self.memory = ReplayMemory(memory_size)
        self.n_actions = env.action_space.n
        if optimizer == 'Adam':
            self.optim = optim.Adam(self.policy.parameters())
        else:
            raise NotImplementedError

    def update_policy(self, gamma):
        memory = self.memory.get_memory()
        episode_length = len(memory)
        memory = list(zip(*memory))
        discounted_rewards = []
        s = memory[0]
        a = memory[1]
        ns = memory[2]
        r = memory[3]
        a_one_hot = torch.nn.functional.one_hot(
            torch.tensor(a), num_classes=self.n_actions).float()
        for t in range(episode_length):
            r_forward = r[t:]
            G = 0
            # compute the discounted cumulative reward from time step t
            for i, reward in enumerate(r_forward):
                G += gamma ** i * reward
            discounted_rewards.append(G)
        rewards_t = torch.tensor(discounted_rewards).detach()
        s_t = torch.tensor(np.array(s)).float()
        selected_action_probs = self.policy(s_t)
        prob = torch.sum(selected_action_probs * a_one_hot, axis=1)
        # a small hack to prevent inf from log(0)
        clipped = torch.clamp(prob, min=1e-10, max=1.0)
        J = -torch.log(clipped) * rewards_t
        grad = torch.sum(J)
        self.optim.zero_grad()
        grad.backward()
        self.optim.step()

    def train(self, num_epochs, gamma):
        for e in range(num_epochs):
            # flush the on-policy memory at the start of every episode
            self.memory.flush()
            state = self.env.reset()
            done = False
            total_reward = 0
            self.env.render()
            t = 0
            while not done:
                action_prob = self.policy(torch.from_numpy(state).float())
                # print(action_prob.detach().numpy())
                # action = torch.argmax(action_prob).item()
                action = np.random.choice(range(self.n_actions),
                                          p=action_prob.detach().numpy())
                next_state, reward, done, _ = self.env.step(action)
                self.memory.push(state, action, next_state, reward)
                self.env.render()
                total_reward += reward
                state = next_state
                t += 1
            print("Episode: ", e)
            print("Reward: ", total_reward)
            writer.add_scalar('./runs/rewards', total_reward, e)
            self.update_policy(gamma)
        self.env.close()
        writer.close()
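# --- illustration only (not part of the original repository) -------------------------------
# Minimal usage sketch for the REINFORCE Agent above. It assumes the classic gym API that
# `train` already relies on (reset() returns an observation, step() returns a 4-tuple) and
# that Feedforward, ReplayMemory and the module-level `writer` exist in this module. The
# environment name and the hyperparameter values are illustrative, not taken from the original.
if __name__ == '__main__':
    import gym

    env = gym.make('CartPole-v1')
    agent = Agent(algo='reinforce', optimizer='Adam', env=env,
                  num_actions=env.action_space.n)
    agent.train(num_epochs=500, gamma=0.99)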