コード例 #1
0
ファイル: vail_nf.py プロジェクト: sunghoonhong/VI-GAIL
class GAIL:
    """GAIL/VAIL agent: PPO actor-critic trained on a learned discriminator reward.

    The actor, critic and discriminator all take one shared ``VAE_Encoder``.
    The discriminator is trained so its output moves toward 1 on agent samples
    and toward 0 on expert samples. A KL term weighted by ``self.beta``
    regularizes its latent code (``tf_gaussian_KL(mu, 0, std, 1)``, presumably
    KL to a standard normal). Its output is turned into a reward that drives
    clipped-PPO updates of the actor and critic.
    """

    def __init__(self, vail_sample, reward_shift, reward_aug, gae_norm,
                 global_norm, actor_lr, critic_lr, disc_lr, actor_units,
                 critic_units, disc_units, disc_reduce_units, gamma, lambd,
                 clip, entropy, epochs, batch_size, update_rate, data_dir,
                 demo_list):
        """Build networks, memories and the initial expert-demonstration pool.

        Args:
            vail_sample: if true, the discriminator decodes the sampled latent;
                otherwise it decodes the mean ``mu``.
            reward_shift: use ``-log(2 * D)`` instead of ``-log(D)`` as reward.
            reward_aug: add the environment rewards kept in memory to the
                discriminator reward.
            gae_norm: standardize advantages within each processed rollout.
            global_norm: clip-by-global-norm threshold; ``<= 0`` disables it.
            actor_lr, critic_lr, disc_lr: learning rate of each network.
            actor_units, critic_units, disc_units, disc_reduce_units: layer
                sizes forwarded to the network constructors.
            gamma, lambd: discount factor and GAE lambda.
            clip: PPO ratio-clipping epsilon.
            entropy: weight of the entropy-style term in the actor loss.
            epochs: PPO epochs per ``update`` call.
            batch_size: PPO minibatch size; discriminator batches use half
                agent and half expert samples of this total size.
            update_rate: ``update`` runs once ``len(self.replay)`` reaches it.
            data_dir: directory of expert demonstration files.
            demo_list: accepted for backward compatibility only; the listing
                of ``data_dir`` is what is actually sampled from.

        Raises:
            ValueError: if ``data_dir`` holds fewer than ``demo_group_num``
                files.
        """
        # build network
        self.actor = Actor(lr=actor_lr, hidden_units=actor_units)
        self.critic = Critic(lr=critic_lr, hidden_units=critic_units)
        self.discriminator = Discriminator(lr=disc_lr,
                                           hidden_units=disc_units,
                                           reduce_units=disc_reduce_units)
        self.encoder = VAE_Encoder(latent_num=64)

        # set hyperparameters
        self.vail_sample = vail_sample
        self.reward_shift = reward_shift
        self.reward_aug = reward_aug
        self.gae_norm = gae_norm
        self.gamma = gamma
        self.lambd = lambd
        self.gam_lam = gamma * lambd
        self.clip = clip
        self.entropy = entropy
        self.epochs = epochs
        self.batch_size = batch_size
        self.half_batch_size = batch_size // 2
        self.update_rate = update_rate
        self.grad_global_norm = global_norm
        self.beta = BETA_INIT

        # build memory
        self.memory = HorizonMemory(use_reward=reward_aug)
        self.replay = ReplayMemory()

        # build expert demonstration pipeline
        self.data_dir = data_dir
        self.demo_list = os.listdir(data_dir)
        self.demo_group_num = 500
        self.demo_rotate = 5  # not read anywhere in this class
        # Validate the list set_demo() really samples from, and do it with an
        # exception so the check survives ``python -O``.
        if len(self.demo_list) < self.demo_group_num:
            raise ValueError(
                'need at least %d demonstrations in %r, found %d'
                % (self.demo_group_num, data_dir, len(self.demo_list)))
        self.set_demo()

        # ready
        self.dummy_forward()
        self.actor_vars = self.actor.trainable_variables + self.encoder.trainable_variables
        self.critic_vars = self.critic.trainable_variables + self.encoder.trainable_variables
        self.disc_vars = self.discriminator.trainable_variables + self.encoder.trainable_variables

    def dummy_forward(self):
        """Call every network once on zero inputs ("connect networks").

        ``__init__`` reads ``trainable_variables`` right after this call. The
        call presumably makes lazily-built layers create their weights;
        confirm against the network classes.
        """
        dummy_state = np.zeros([1] + STATE_SHAPE, dtype=np.float32)
        dummy_action = np.zeros([1] + ACTION_SHAPE, dtype=np.float32)
        self.encoder(dummy_state)
        self.actor(self.encoder, dummy_state)
        self.critic(self.encoder, dummy_state)
        self.discriminator(self.encoder, dummy_state, dummy_action)

    def set_demo(self):
        """Load ``demo_group_num`` random demonstration files into memory.

        The method re-lists ``self.data_dir``, so newly added files are seen.
        It samples file names without replacement. It then concatenates each
        file's ``'state'`` and ``'action'`` arrays along axis 0 into
        ``self.expert_states`` and ``self.expert_actions``.

        Raises:
            ValueError: from ``random.sample`` if the directory holds fewer
                than ``demo_group_num`` files.
        """
        # Must be the instance attribute: no ``data_dir`` exists in this scope.
        self.demo_list = os.listdir(self.data_dir)
        selected_demos = random.sample(self.demo_list, self.demo_group_num)

        expert_states = []
        expert_actions = []
        for demo_name in selected_demos:
            demo = np.load(os.path.join(self.data_dir, demo_name))
            try:
                expert_states.append(demo['state'])
                expert_actions.append(demo['action'])
            finally:
                # ``.npz`` archives load as NpzFile objects that keep a file
                # handle open; close them rather than relying on GC.
                close = getattr(demo, 'close', None)
                if close is not None:
                    close()
        self.expert_states = np.concatenate(expert_states, axis=0)
        self.expert_actions = np.concatenate(expert_actions, axis=0)

    def get_demonstration(self, sample_num):
        """Return ``sample_num`` randomly chosen expert (state, action) rows.

        If the loaded pool has fewer than ``sample_num`` rows, a fresh group of
        demonstrations is loaded first. If it is still too small, rows are
        drawn with replacement, so the result always has exactly
        ``sample_num`` rows. Both arrays are indexed with the same indices, so
        their rows stay aligned.
        """
        if len(self.expert_states) < sample_num:
            self.set_demo()
        # Build the index *after* a possible reload so it spans the new pool.
        pool_size = len(self.expert_states)
        if pool_size >= sample_num:
            index = np.random.permutation(pool_size)[:sample_num]
        else:
            index = np.random.randint(0, pool_size, size=sample_num)
        return self.expert_states[index], self.expert_actions[index]

    def memory_process(self, next_state, done):
        """Turn the buffered horizon into advantages/targets and store them.

        Rewards come from the discriminator, plus the environment rewards when
        ``reward_aug`` is set. Advantages and targets come from
        ``get_gae_oracle``. The results go into ``self.replay`` and the horizon
        memory is flushed. Once the replay holds ``update_rate`` entries,
        ``update`` runs and the replay is flushed.

        Args:
            next_state: state observed after the last stored step; it is
                concatenated after the stored states for bootstrapping.
            done: whether the episode ended (the bootstrap value becomes 0).
        """
        # NOTE(review): element shapes come from HorizonMemory.rollout() and are
        # not visible here. Each entry is concatenated along axis 0, so they are
        # presumably batch-of-one arrays; confirm.
        if self.reward_aug:
            states, actions, log_old_pis, rewards = self.memory.rollout()
        else:
            states, actions, log_old_pis = self.memory.rollout()
        np_states = np.concatenate(states + [next_state], axis=0)
        np_actions = np.concatenate(actions, axis=0)

        np_rewards = self.get_reward(np_states[:-1], np_actions)  # (N, 1)
        if self.reward_aug:
            np_env_rewards = np.stack(rewards, axis=0).reshape(-1, 1)
            np_rewards = np_rewards + np_env_rewards
        gae, oracle = self.get_gae_oracle(np_states, np_rewards,
                                          done)  # (N, 1), (N, 1)
        self.replay.append(states, actions, log_old_pis, gae, oracle)
        self.memory.flush()
        if len(self.replay) >= self.update_rate:
            self.update()
            self.replay.flush()

    def get_action(self, state):
        """Sample an action from the current policy.

        Only row 0 of the policy output is used, so ``state`` is presumably a
        batch of one.

        Returns:
            ``(action, action_one_hot, log_old_pi, policy)``: the sampled
            index, a ``(1, ACTION_NUM)`` float32 one-hot, ``[[log pi(a)]]`` and
            the full probability vector.
        """
        policy = self.actor(self.encoder, state).numpy()[0]
        action = np.random.choice(ACTION_NUM, p=policy)
        action_one_hot = np.eye(ACTION_NUM,
                                dtype=np.float32)[[action]]  # (1, ACTION_NUM)
        log_old_pi = [[np.log(policy[action] + 1e-8)]]  # (1, 1)
        return action, action_one_hot, log_old_pi, policy

    def get_reward(self, states, actions):
        """Convert discriminator outputs D into rewards.

        Returns ``-log(2 * D)`` when ``reward_shift`` is set, otherwise
        ``-log(D)``. Because the discriminator is trained toward 1 on agent
        data, lower D (more expert-like) gives a higher reward.
        """
        d = self.discriminator(self.encoder, states, actions).numpy()  # (N, 1)
        if self.reward_shift:
            rewards = -np.log(2.0 * d + 1e-8)  # log equil reward
        else:
            rewards = -np.log(d + 1e-8)  # log reward
        return rewards

    def get_gae_oracle(self, states, rewards, done):
        """Compute GAE advantages and one-step TD targets.

        Args:
            states: ``N + 1`` states; the last one is the bootstrap state.
            rewards: ``(N, 1)`` rewards.
            done: if true, the bootstrap value is forced to 0.

        Returns:
            ``(gaes, oracles)``, both ``(N, 1)`` float32. Each target is
            ``oracle_t = r_t + gamma * V(s_{t+1})``. The advantages are
            standardized when ``gae_norm`` is set.
        """
        values = self.critic(self.encoder, states).numpy()  # (N+1, 1)
        if done:
            values[-1] = 0.0
        n = len(rewards)
        gae = 0.0
        gaes = np.zeros((n, 1), dtype=np.float32)
        oracles = np.zeros((n, 1), dtype=np.float32)
        for t in reversed(range(n)):
            oracles[t] = rewards[t] + self.gamma * values[t + 1]
            # Work on scalars; storing size-1 arrays into scalar slots is
            # deprecated in recent NumPy.
            delta = oracles[t, 0] - values[t, 0]
            gae = delta + self.gam_lam * gae
            gaes[t, 0] = gae

        if self.gae_norm:
            gaes = (gaes - np.mean(gaes)) / (np.std(gaes) + 1e-8)
        return gaes, oracles

    def update(self):
        """Train the discriminator, then the critic and actor, on the replay."""
        states, actions, log_old_pis, gaes, oracles = self.replay.rollout()

        states = np.concatenate(states, axis=0)
        actions = np.concatenate(actions, axis=0)
        log_old_pis = np.concatenate(log_old_pis, axis=0)
        gaes = np.concatenate(gaes, axis=0)
        oracles = np.concatenate(oracles, axis=0)

        self._update_discriminator(states, actions)
        # TODO: update posterior
        # L1 loss = logQ(code|s,prev_a,prev_code)
        self._update_actor_critic(states, actions, log_old_pis, gaes, oracles)

    def _apply_gradients(self, optimizer, grads, variables):
        """Optionally clip ``grads`` by global norm, then apply them."""
        if self.grad_global_norm > 0:
            grads, _ = tf.clip_by_global_norm(grads, self.grad_global_norm)
        optimizer.apply_gradients(zip(grads, variables))

    def _update_discriminator(self, states, actions):
        """Run one shuffled pass of discriminator steps over agent vs. expert data.

        Each batch holds ``half_batch_size`` agent rows followed by the same
        number of expert rows. Leftover agent rows that do not fill a half
        batch are skipped.
        """
        n = len(states)
        half = self.half_batch_size
        s_e, a_e = self.get_demonstration(n)

        index = np.arange(n)
        np.random.shuffle(index)
        for i in range(n // half):
            idx = index[i * half:(i + 1) * half]
            s_concat = np.concatenate([states[idx], s_e[idx]], axis=0)
            a_concat = np.concatenate([actions[idx], a_e[idx]], axis=0)

            with tf.GradientTape() as tape:
                mu, std, sampled = self.discriminator.encode(
                    self.encoder, s_concat, a_concat)
                discs = self.discriminator.decode(
                    sampled if self.vail_sample else mu)
                kld_loss = tf.reduce_mean(tf_gaussian_KL(mu, 0, std, 1))
                # Agent half is pushed toward 1, expert half toward 0.
                agent_loss = -tf.reduce_mean(
                    tf.math.log(discs[:half] + 1e-8))
                expert_loss = -tf.reduce_mean(
                    tf.math.log(1 + 1e-8 - discs[half:]))
                discriminator_loss = (agent_loss + expert_loss
                                      + self.beta * kld_loss)
            disc_grads = tape.gradient(discriminator_loss, self.disc_vars)
            self._apply_gradients(self.discriminator.opt, disc_grads,
                                  self.disc_vars)

    def _update_actor_critic(self, states, actions, log_old_pis, gaes,
                             oracles):
        """Run ``epochs`` epochs of minibatch critic (MSE) and actor (PPO) steps.

        In each epoch, trailing samples that do not fill a whole
        ``batch_size`` minibatch are skipped.
        """
        batch_num = len(gaes) // self.batch_size
        index = np.arange(len(gaes))
        for _ in range(self.epochs):
            np.random.shuffle(index)
            for i in range(batch_num):
                idx = index[i * self.batch_size:(i + 1) * self.batch_size]
                state = states[idx]
                action = actions[idx]
                log_old_pi = log_old_pis[idx]
                gae = gaes[idx]
                oracle = oracles[idx]

                # update critic: regress V(s) onto the TD targets
                with tf.GradientTape() as tape:
                    values = self.critic(self.encoder, state)  # (B, 1)
                    critic_loss = tf.reduce_mean((oracle - values)**2)
                critic_grads = tape.gradient(critic_loss, self.critic_vars)
                self._apply_gradients(self.critic.opt, critic_grads,
                                      self.critic_vars)

                # update actor: clipped PPO surrogate + entropy-style term
                with tf.GradientTape() as tape:
                    pred_action = self.actor(self.encoder, state)
                    # log pi(a|s) of the taken one-hot action, shape (B, 1)
                    log_pi = tf.expand_dims(tf.math.log(
                        tf.reduce_sum(pred_action * action, axis=1) + 1e-8),
                                            axis=1)
                    ratio = tf.exp(log_pi - log_old_pi)
                    clip_ratio = tf.clip_by_value(ratio, 1 - self.clip,
                                                  1 + self.clip)
                    clip_loss = -tf.reduce_mean(
                        tf.minimum(ratio * gae, clip_ratio * gae))
                    # NOTE(review): this is p*log(p) of the taken action only,
                    # not the full-distribution entropy; confirm it is intended.
                    entropy = tf.reduce_mean(tf.exp(log_pi) * log_pi)
                    actor_loss = clip_loss + self.entropy * entropy
                # NOTE: freeze posterior
                actor_grads = tape.gradient(actor_loss, self.actor_vars)
                self._apply_gradients(self.actor.opt, actor_grads,
                                      self.actor_vars)

    def save_model(self, dir, tag=''):
        """Save every network's weights to ``dir + tag + '<net>.h5'``.

        Paths are built by plain concatenation, so ``dir`` should end with a
        path separator.
        """
        self.actor.save_weights(dir + tag + 'actor.h5')
        self.critic.save_weights(dir + tag + 'critic.h5')
        self.discriminator.save_weights(dir + tag + 'discriminator.h5')
        self.encoder.save_weights(dir + tag + 'encoder.h5')

    def load_model(self, dir, tag=''):
        """Load whichever of the ``save_model`` weight files exist; skip missing ones."""
        if os.path.exists(dir + tag + 'actor.h5'):
            self.actor.load_weights(dir + tag + 'actor.h5')
            print('Actor loaded... %s%sactor.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'critic.h5'):
            self.critic.load_weights(dir + tag + 'critic.h5')
            print('Critic loaded... %s%scritic.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'discriminator.h5'):
            self.discriminator.load_weights(dir + tag + 'discriminator.h5')
            print('Discriminator loaded... %s%sdiscriminator.h5' % (dir, tag))
        self.load_encoder(dir, tag)

    def load_encoder(self, dir, tag=''):
        """Load only the shared encoder's weights, if the file exists."""
        if os.path.exists(dir + tag + 'encoder.h5'):
            self.encoder.load_weights(dir + tag + 'encoder.h5')
            print('encoder loaded... %s%sencoder.h5' % (dir, tag))
コード例 #2
0
class Agent:
    """REINFORCE (Monte-Carlo policy-gradient) agent for a discrete-action env.

    Each training epoch rolls out one episode with the current policy. It then
    takes one optimizer step on ``sum_t -log pi(a_t | s_t) * G_t``, where
    ``G_t`` is the discounted return from step ``t``.
    """

    def __init__(self, algo, optimizer, env, num_actions, memory_size=10000):
        """Create the policy network, episode memory and optimizer.

        Args:
            algo: stored but currently unused (placeholder for pluggable
                learning algorithms).
            optimizer: optimizer name; only ``'Adam'`` is supported.
            env: environment exposing ``observation_space.shape[0]`` and a
                discrete ``action_space.n``.
            num_actions: unused; the action count is read from
                ``env.action_space.n``.
            memory_size: forwarded to ``ReplayMemory`` (presumably its
                capacity).

        Raises:
            NotImplementedError: for any optimizer other than ``'Adam'``.
        """
        self.algo = algo  # currently not used. Future work to make learning algorithms modular.
        self.env = env
        self.policy = Feedforward(env.observation_space.shape[0],
                                  env.action_space.n)

        # ReplayMemory is reused here only as an episode buffer. Policy
        # gradient is on-policy, so experience from an older policy must not
        # be reused. train() flushes the memory at the start of every epoch,
        # so each update sees only trajectories from the current policy.
        self.memory = ReplayMemory(memory_size)

        self.n_actions = env.action_space.n
        if optimizer == 'Adam':
            self.optim = optim.Adam(self.policy.parameters())
        else:
            raise NotImplementedError

    @staticmethod
    def _discounted_returns(rewards, gamma):
        """Return ``[G_0, ..., G_{T-1}]`` with ``G_t = sum_i gamma**i * r_{t+i}``.

        Uses one backward pass (``G_t = r_t + gamma * G_{t+1}``) instead of
        re-summing every suffix, which would be O(T^2).
        """
        returns = []
        running = 0.0
        for reward in reversed(rewards):
            running = reward + gamma * running
            returns.append(running)
        returns.reverse()
        return returns

    def update_policy(self, gamma):
        """Take one REINFORCE gradient step on the episode held in memory.

        Does nothing if the memory is empty.

        Args:
            gamma: discount factor used for the returns ``G_t``.
        """
        transitions = self.memory.get_memory()
        if len(transitions) == 0:
            return
        # Entries are (state, action, next_state, reward), as pushed in train().
        states, actions, _next_states, rewards = zip(*transitions)

        returns_t = torch.tensor(self._discounted_returns(rewards, gamma),
                                 dtype=torch.float32)
        s_t = torch.as_tensor(np.asarray(states), dtype=torch.float32)
        # The one-hot width must match the policy output width. A hard-coded 2
        # breaks every environment with more than two actions.
        a_one_hot = torch.nn.functional.one_hot(
            torch.as_tensor(np.asarray(actions), dtype=torch.long),
            num_classes=self.n_actions).float()

        action_probs = self.policy(s_t)
        prob = torch.sum(action_probs * a_one_hot, dim=1)
        # Clamp so log(0) cannot produce -inf.
        clipped = torch.clamp(prob, min=1e-10, max=1.0)
        loss = torch.sum(-torch.log(clipped) * returns_t)

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

    def train(self, num_epochs, gamma):
        """Alternate one on-policy episode rollout with one policy update.

        Memory is flushed at the start of every epoch. Each episode's total
        reward is printed and logged through the module-level ``writer``. The
        environment and ``writer`` are closed when training ends.

        Args:
            num_epochs: number of episodes (and updates) to run.
            gamma: discount factor passed to ``update_policy``.
        """
        for e in range(num_epochs):
            self.memory.flush()
            state = self.env.reset()
            done = False
            total_reward = 0
            self.env.render()

            while not done:
                # Sampling needs no autograd graph; update_policy recomputes
                # the probabilities with gradients enabled.
                with torch.no_grad():
                    action_prob = self.policy(
                        torch.from_numpy(state).float())
                action = np.random.choice(range(self.n_actions),
                                          p=action_prob.numpy())
                next_state, reward, done, _ = self.env.step(action)
                self.memory.push(state, action, next_state, reward)
                self.env.render()

                total_reward += reward
                state = next_state

            print("Episode: ", e)
            print("Reward: ", total_reward)
            writer.add_scalar('./runs/rewards', total_reward, e)

            self.update_policy(gamma)
        self.env.close()
        writer.close()