class GAIL:
    def __init__(self, vail_sample, reward_shift, reward_aug, gae_norm,
                 global_norm, actor_lr, critic_lr, disc_lr,
                 actor_units, critic_units, disc_units, disc_reduce_units,
                 gamma, lambd, clip, entropy, epochs, batch_size,
                 update_rate, data_dir, demo_list):
        # build networks
        self.actor = Actor(lr=actor_lr, hidden_units=actor_units)
        self.critic = Critic(lr=critic_lr, hidden_units=critic_units)
        self.discriminator = Discriminator(lr=disc_lr, hidden_units=disc_units,
                                           reduce_units=disc_reduce_units)
        self.encoder = VAE_Encoder(latent_num=64)
        # set hyperparameters
        self.vail_sample = vail_sample
        self.reward_shift = reward_shift
        self.reward_aug = reward_aug
        self.gae_norm = gae_norm
        self.gamma = gamma
        self.lambd = lambd
        self.gam_lam = gamma * lambd
        self.clip = clip
        self.entropy = entropy
        self.epochs = epochs
        self.batch_size = batch_size
        self.half_batch_size = batch_size // 2
        self.update_rate = update_rate
        self.grad_global_norm = global_norm
        self.beta = BETA_INIT
        # build memory
        self.memory = HorizonMemory(use_reward=reward_aug)
        self.replay = ReplayMemory()
        # build expert demonstration pipeline
        self.data_dir = data_dir
        self.demo_list = os.listdir(data_dir)
        self.demo_group_num = 500
        self.demo_rotate = 5
        assert len(demo_list) >= self.demo_group_num
        self.set_demo()
        # ready: build weights and collect trainable variables
        self.dummy_forward()
        self.actor_vars = self.actor.trainable_variables + self.encoder.trainable_variables
        self.critic_vars = self.critic.trainable_variables + self.encoder.trainable_variables
        self.disc_vars = self.discriminator.trainable_variables + self.encoder.trainable_variables

    def dummy_forward(self):
        # connect networks by pushing a dummy batch through them
        dummy_state = np.zeros([1] + STATE_SHAPE, dtype=np.float32)
        dummy_action = np.zeros([1] + ACTION_SHAPE, dtype=np.float32)
        self.encoder(dummy_state)
        self.actor(self.encoder, dummy_state)
        self.critic(self.encoder, dummy_state)
        self.discriminator(self.encoder, dummy_state, dummy_action)

    def set_demo(self):
        # sample a fresh subset of expert demonstrations from disk
        self.demo_list = os.listdir(self.data_dir)
        selected_demos = random.sample(self.demo_list, self.demo_group_num)
        expert_states = []
        expert_actions = []
        for demo_name in selected_demos:
            demo = np.load(self.data_dir + demo_name)
            expert_states.append(demo['state'])
            expert_actions.append(demo['action'])
        self.expert_states = np.concatenate(expert_states, axis=0)
        self.expert_actions = np.concatenate(expert_actions, axis=0)
        del demo

    def get_demonstration(self, sample_num):
        # resample the demonstration pool if it is too small for this request
        if len(self.expert_states) < sample_num:
            self.set_demo()
        index = np.arange(len(self.expert_states))
        np.random.shuffle(index)
        index = index[:sample_num]
        return self.expert_states[index], self.expert_actions[index]

    def memory_process(self, next_state, done):
        # rollout layout: states [[(1,64,64,3)], ...], actions [[(1,2),(1,9),(1,3),(1,4)], ...],
        # log_old_pis [[c_pi, d_pi, s_pi, a_pi], ...]
        if self.reward_aug:
            states, actions, log_old_pis, rewards = self.memory.rollout()
        else:
            states, actions, log_old_pis = self.memory.rollout()
        np_states = np.concatenate(states + [next_state], axis=0)
        np_actions = np.concatenate(actions, axis=0)
        np_rewards = self.get_reward(np_states[:-1], np_actions)   # (N, 1)
        if self.reward_aug:
            np_env_rewards = np.stack(rewards, axis=0).reshape(-1, 1)
            np_rewards = np_rewards + np_env_rewards
        gae, oracle = self.get_gae_oracle(np_states, np_rewards, done)   # (N, 1), (N, 1)
        self.replay.append(states, actions, log_old_pis, gae, oracle)
        self.memory.flush()
        if len(self.replay) >= self.update_rate:
            self.update()
            self.replay.flush()

    def get_action(self, state):
        policy = self.actor(self.encoder, state).numpy()[0]
        action = np.random.choice(ACTION_NUM, p=policy)
        # action = np.argmax(policy)
        action_one_hot = np.eye(ACTION_NUM, dtype=np.float32)[[action]]   # (1, ACTION_NUM)
        log_old_pi = [[np.log(policy[action] + 1e-8)]]   # (1, 1)
        return action, action_one_hot, log_old_pi, policy

    def get_reward(self, states, actions):
        d = self.discriminator(self.encoder, states, actions).numpy()   # (N, 1)
        # rewards = 0.5 - d           # linear reward
        # rewards = np.tan(0.5 - d)   # tan reward
        if self.reward_shift:
            rewards = -np.log(2.0 * d + 1e-8)   # shifted log reward (zero at d = 0.5)
        else:
            rewards = -np.log(d + 1e-8)         # log reward
        # rewards = 0.1 * np.where(rewards > 1, 1, rewards)
        return rewards

    def get_gae_oracle(self, states, rewards, done):
        # `states` includes the next state, so values has length N + 1
        values = self.critic(self.encoder, states).numpy()   # (N+1, 1)
        if done:
            values[-1] = np.float32([0])
        N = len(rewards)
        gae = 0
        gaes = np.zeros((N, 1), dtype=np.float32)
        oracles = np.zeros((N, 1), dtype=np.float32)
        for t in reversed(range(N)):
            oracles[t] = rewards[t] + self.gamma * values[t + 1]
            delta = oracles[t] - values[t]
            gae = delta + self.gam_lam * gae
            gaes[t][0] = gae
        # oracles = gaes + values[:-1]   # (N, 1)
        if self.gae_norm:
            gaes = (gaes - np.mean(gaes)) / (np.std(gaes) + 1e-8)
        return gaes, oracles

    def update(self):
        # load & flatten rollout data
        states, actions, log_old_pis, gaes, oracles = self.replay.rollout()
        states = np.concatenate(states, axis=0)
        actions = np.concatenate(actions, axis=0)
        log_old_pis = np.concatenate(log_old_pis, axis=0)
        gaes = np.concatenate(gaes, axis=0)
        oracles = np.concatenate(oracles, axis=0)
        N = len(states)

        # update discriminator on half-agent / half-expert batches
        s_e, a_e = self.get_demonstration(N)
        batch_num = N // self.half_batch_size
        index = np.arange(N)
        np.random.shuffle(index)
        for i in range(batch_num):
            idx = index[i * self.half_batch_size:(i + 1) * self.half_batch_size]
            s_concat = np.concatenate([states[idx], s_e[idx]], axis=0)
            a_concat = np.concatenate([actions[idx], a_e[idx]], axis=0)
            with tf.GradientTape(persistent=True) as tape:
                mu, std, sampled = self.discriminator.encode(
                    self.encoder, s_concat, a_concat)
                discs = self.discriminator.decode(
                    sampled if self.vail_sample else mu)
                kld_loss = tf.reduce_mean(tf_gaussian_KL(mu, 0, std, 1))
                agent_loss = -tf.reduce_mean(
                    tf.math.log(discs[:self.half_batch_size] + 1e-8))
                expert_loss = -tf.reduce_mean(
                    tf.math.log(1 + 1e-8 - discs[self.half_batch_size:]))
                disc_loss = agent_loss + expert_loss
                discriminator_loss = disc_loss + self.beta * kld_loss
            disc_grads = tape.gradient(discriminator_loss, self.disc_vars)
            if self.grad_global_norm > 0:
                disc_grads, _ = tf.clip_by_global_norm(disc_grads, self.grad_global_norm)
            self.discriminator.opt.apply_gradients(zip(disc_grads, self.disc_vars))
            del tape
        # TODO: update posterior
        # L1 loss = logQ(code|s, prev_a, prev_code)

        # update actor & critic
        # batch_num = math.ceil(len(states) / self.batch_size)
        batch_num = len(gaes) // self.batch_size
        index = np.arange(len(gaes))
        for _ in range(self.epochs):
            np.random.shuffle(index)
            for i in range(batch_num):
                # if i == batch_num - 1:
                #     idx = index[i * self.batch_size:]
                # else:
                idx = index[i * self.batch_size:(i + 1) * self.batch_size]
                state = states[idx]
                action = actions[idx]
                log_old_pi = log_old_pis[idx]
                gae = gaes[idx]
                oracle = oracles[idx]
                # update critic
                with tf.GradientTape(persistent=True) as tape:
                    values = self.critic(self.encoder, state)   # (N, 1)
                    critic_loss = tf.reduce_mean((oracle - values) ** 2)   # MSE loss
                critic_grads = tape.gradient(critic_loss, self.critic_vars)
                if self.grad_global_norm > 0:
                    critic_grads, _ = tf.clip_by_global_norm(
                        critic_grads, self.grad_global_norm)
                self.critic.opt.apply_gradients(zip(critic_grads, self.critic_vars))
                del tape
                # update actor
                with tf.GradientTape(persistent=True) as tape:
                    pred_action = self.actor(self.encoder, state)
                    # RL (PPO) term
                    log_pi = tf.expand_dims(tf.math.log(
                        tf.reduce_sum(pred_action * action, axis=1) + 1e-8), axis=1)   # (N, 1)
                    ratio = tf.exp(log_pi - log_old_pi)
                    clip_ratio = tf.clip_by_value(ratio, 1 - self.clip, 1 + self.clip)
                    clip_loss = -tf.reduce_mean(
                        tf.minimum(ratio * gae, clip_ratio * gae))
                    entropy = tf.reduce_mean(tf.exp(log_pi) * log_pi)
                    actor_loss = clip_loss + self.entropy * entropy
                actor_grads = tape.gradient(actor_loss, self.actor_vars)   # NOTE: freeze posterior
                if self.grad_global_norm > 0:
                    actor_grads, _ = tf.clip_by_global_norm(
                        actor_grads, self.grad_global_norm)
                self.actor.opt.apply_gradients(zip(actor_grads, self.actor_vars))
                del tape
        # print('%d samples trained... D loss: %.4f C loss: %.4f A loss: %.4f\t\t\t'
        #       % (len(gaes), disc_loss, critic_loss, actor_loss), end='\r')

    def save_model(self, dir, tag=''):
        self.actor.save_weights(dir + tag + 'actor.h5')
        self.critic.save_weights(dir + tag + 'critic.h5')
        self.discriminator.save_weights(dir + tag + 'discriminator.h5')
        self.encoder.save_weights(dir + tag + 'encoder.h5')

    def load_model(self, dir, tag=''):
        if os.path.exists(dir + tag + 'actor.h5'):
            self.actor.load_weights(dir + tag + 'actor.h5')
            print('Actor loaded... %s%sactor.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'critic.h5'):
            self.critic.load_weights(dir + tag + 'critic.h5')
            print('Critic loaded... %s%scritic.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'discriminator.h5'):
            self.discriminator.load_weights(dir + tag + 'discriminator.h5')
            print('Discriminator loaded... %s%sdiscriminator.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'encoder.h5'):
            self.encoder.load_weights(dir + tag + 'encoder.h5')
            print('encoder loaded... %s%sencoder.h5' % (dir, tag))

    def load_encoder(self, dir, tag=''):
        if os.path.exists(dir + tag + 'encoder.h5'):
            self.encoder.load_weights(dir + tag + 'encoder.h5')
            print('encoder loaded... %s%sencoder.h5' % (dir, tag))
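
# ----------------------------------------------------------------------------
# Example usage (sketch only): how a training script might drive this agent.
# The environment interface (`env.reset`/`env.step`) and the
# HorizonMemory.append signature below are assumptions for illustration,
# not the repo's confirmed APIs; adjust to the actual training script.
#
#   agent = GAIL(..., data_dir=DATA_DIR, demo_list=os.listdir(DATA_DIR))
#   agent.load_model(SAVE_DIR)
#   state = env.reset()                                   # shaped [1] + STATE_SHAPE
#   while training:
#       action, action_one_hot, log_old_pi, _ = agent.get_action(state)
#       next_state, env_reward, done = env.step(action)
#       agent.memory.append(state, action_one_hot, log_old_pi)   # (+ env_reward if reward_aug)
#       state = next_state
#       if done or horizon_end:
#           agent.memory_process(next_state, done)        # GAE + replay; triggers update()
#           if done:
#               state = env.reset()
#   agent.save_model(SAVE_DIR)
# ----------------------------------------------------------------------------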
# --------------------------------------------------------------------------
# Separate fragment (PyTorch): body of a DQN training loop over batched envs.
eval_net.eval()
values = eval_net(*state)
# print(values)
# print(state)
if random.random() > epsilon:
    # sample uniformly over valid (unmasked) actions
    prob_uniform = (values > -9999999).float()
    dist = Categorical(prob_uniform)
    action = dist.sample()
else:
    # greedy action from the current Q-value estimates
    action = torch.argmax(values, dim=1)
next_state, reward, done = envs.step(action.cpu().numpy())
# store one transition per environment in the replay memory
for idx in range(state[0].size(0)):
    memory.append([state[0][[idx]].cpu(), state[1][[idx]].cpu(),
                   action[[idx]].cpu(), reward[[idx]].cpu(),
                   next_state[0][[idx]].cpu(), next_state[1][[idx]].cpu(),
                   done[[idx]].cpu()])
state = next_state
if len(memory) >= memory.capacity:
    learn(memory, eval_net, target_net, learn_step_counter, args.double_dqn)
    learn_step_counter += 1
if _i % 50 == 0:
    ret = evaluate_batch(eval_net, 0, 100)
    ret = ret[0]
    performance.append(ret)
    time_now = time.time()
    print('average performance on {} {}{} {}: {:.4f}, time: {:.4f}, step: {}'.format(
        args.fn, args.func_type, args.package_num, _i, performance[-1],
        time_now - time_last, learn_step_counter // q_network_iteration))
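
# --------------------------------------------------------------------------
# Illustrative sketch only: a generic (double) DQN update of the kind the
# `learn(...)` call above is expected to perform. The repo's real `learn`
# is defined elsewhere and may differ; the single-tensor `states` batch,
# `optimizer`, and `gamma` names below are assumptions for illustration.
import torch
import torch.nn.functional as F


def double_dqn_update_sketch(eval_net, target_net, optimizer, batch, gamma=0.99):
    """batch = (states, actions, rewards, next_states, dones); dones is a float 0/1 tensor."""
    states, actions, rewards, next_states, dones = batch
    # Q(s, a) for the actions actually taken
    q_sa = eval_net(states).gather(1, actions.long().unsqueeze(1))
    with torch.no_grad():
        # double DQN: choose a* with the online net, evaluate it with the target net
        best_next = eval_net(next_states).argmax(dim=1, keepdim=True)
        q_next = target_net(next_states).gather(1, best_next)
        # (vanilla DQN would instead use target_net(next_states).max(1, keepdim=True)[0])
        target = rewards.unsqueeze(1) + gamma * (1.0 - dones.unsqueeze(1)) * q_next
    loss = F.mse_loss(q_sa, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()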