Example #1
class GAIL:
    def __init__(self, vail_sample, reward_shift, reward_aug, gae_norm,
                 global_norm, actor_lr, critic_lr, disc_lr, actor_units,
                 critic_units, disc_units, disc_reduce_units, gamma, lambd,
                 clip, entropy, epochs, batch_size, update_rate, data_dir,
                 demo_list):
        # build network
        self.actor = Actor(lr=actor_lr, hidden_units=actor_units)
        self.critic = Critic(lr=critic_lr, hidden_units=critic_units)
        self.discriminator = Discriminator(lr=disc_lr,
                                           hidden_units=disc_units,
                                           reduce_units=disc_reduce_units)
        self.encoder = VAE_Encoder(latent_num=64)

        # set hyperparameters
        self.vail_sample = vail_sample
        self.reward_shift = reward_shift
        self.reward_aug = reward_aug
        self.gae_norm = gae_norm
        self.gamma = gamma
        self.lambd = lambd
        self.gam_lam = gamma * lambd
        self.clip = clip
        self.entropy = entropy
        self.epochs = epochs
        self.batch_size = batch_size
        self.half_batch_size = batch_size // 2
        self.update_rate = update_rate
        self.grad_global_norm = global_norm
        self.beta = BETA_INIT

        # build memory
        self.memory = HorizonMemory(use_reward=reward_aug)
        self.replay = ReplayMemory()

        # build expert demonstration Pipeline
        self.data_dir = data_dir
        self.demo_list = os.listdir(data_dir)
        self.demo_group_num = 500
        self.demo_rotate = 5
        assert len(self.demo_list) >= self.demo_group_num
        self.set_demo()

        # ready
        self.dummy_forward()
        self.actor_vars = self.actor.trainable_variables + self.encoder.trainable_variables
        self.critic_vars = self.critic.trainable_variables + self.encoder.trainable_variables
        self.disc_vars = self.discriminator.trainable_variables + self.encoder.trainable_variables

    def dummy_forward(self):
        # connect networks
        dummy_state = np.zeros([1] + STATE_SHAPE, dtype=np.float32)
        dummy_action = np.zeros([1] + ACTION_SHAPE, dtype=np.float32)
        self.encoder(dummy_state)
        self.actor(self.encoder, dummy_state)
        self.critic(self.encoder, dummy_state)
        self.discriminator(self.encoder, dummy_state, dummy_action)

    def set_demo(self):
        self.demo_list = os.listdir(self.data_dir)
        selected_demos = random.sample(self.demo_list, self.demo_group_num)

        expert_states = []
        expert_actions = []
        for demo_name in selected_demos:
            demo = np.load(self.data_dir + demo_name)
            states = demo['state']
            actions = demo['action']

            expert_states.append(states)
            expert_actions.append(actions)
        self.expert_states = np.concatenate(expert_states, axis=0)
        self.expert_actions = np.concatenate(expert_actions, axis=0)
        del demo

    def get_demonstration(self, sample_num):
        # resample a fresh demonstration pool if the current one is too small
        if len(self.expert_states) < sample_num:
            self.set_demo()
        index = np.arange(len(self.expert_states))
        np.random.shuffle(index)
        index = index[:sample_num]
        return self.expert_states[index], self.expert_actions[index]

    def memory_process(self, next_state, done):
        # [[(1,64,64,3)], [], ...], [[(1,2),(1,9),(1,3),(1,4)], [], ...], [[c_pi, d_pi, s_pi, a_pi], [], ...]
        if self.reward_aug:
            states, actions, log_old_pis, rewards = self.memory.rollout()
        else:
            states, actions, log_old_pis = self.memory.rollout()
        np_states = np.concatenate(states + [next_state], axis=0)
        np_actions = np.concatenate(actions, axis=0)

        np_rewards = self.get_reward(np_states[:-1], np_actions)  # (N, 1)
        if self.reward_aug:
            np_env_rewards = np.stack(rewards, axis=0).reshape(-1, 1)
            np_rewards = np_rewards + np_env_rewards
        gae, oracle = self.get_gae_oracle(np_states, np_rewards,
                                          done)  # (N, 1), (N, 1)
        self.replay.append(states, actions, log_old_pis, gae, oracle)
        self.memory.flush()
        if len(self.replay) >= self.update_rate:
            self.update()
            self.replay.flush()

    def get_action(self, state):
        policy = self.actor(self.encoder, state).numpy()[0]
        action = np.random.choice(ACTION_NUM, p=policy)
        # action = np.argmax(policy)
        action_one_hot = np.eye(ACTION_NUM,
                                dtype=np.float32)[[action]]  # (1, ACTION_NUM)
        log_old_pi = [[np.log(policy[action] + 1e-8)]]  # (1, 1)
        return action, action_one_hot, log_old_pi, policy

    def get_reward(self, states, actions):
        d = self.discriminator(self.encoder, states, actions).numpy()  # (N, 1)
        # rewards = 0.5 - d       # linear reward
        # rewards = np.tan(0.5 - d)     # tan reward
        if self.reward_shift:
            rewards = -np.log(2.0 * d + 1e-8)  # shifted log reward, zero at d = 0.5
        else:
            rewards = -np.log(d + 1e-8)  # log reward
        # rewards = 0.1 * np.where(rewards>1, 1, rewards)
        return rewards

    def get_gae_oracle(self, states, rewards, done):
        # states include next state
        values = self.critic(self.encoder, states).numpy()  # (N+1, 1)
        if done:
            values[-1] = np.float32([0])
        N = len(rewards)
        gae = 0
        gaes = np.zeros((N, 1), dtype=np.float32)
        oracles = np.zeros((N, 1), dtype=np.float32)
        for t in reversed(range(N)):
            # one-step TD target, TD error, and backward-accumulated GAE
            oracles[t] = rewards[t] + self.gamma * values[t + 1]
            delta = oracles[t] - values[t]
            gae = delta + self.gam_lam * gae
            gaes[t][0] = gae

        # oracles = gaes + values[:-1]        # (N, 1)
        if self.gae_norm:
            gaes = (gaes - np.mean(gaes)) / (np.std(gaes) + 1e-8)
        return gaes, oracles

    def update(self):
        # load & calculate data
        states, actions, log_old_pis, gaes, oracles \
            = self.replay.rollout()

        states = np.concatenate(states, axis=0)
        actions = np.concatenate(actions, axis=0)
        log_old_pis = np.concatenate(log_old_pis, axis=0)
        gaes = np.concatenate(gaes, axis=0)
        oracles = np.concatenate(oracles, axis=0)
        N = len(states)
        # update discriminator
        # load expert demonstration
        s_e, a_e = self.get_demonstration(N)

        batch_num = N // self.half_batch_size
        index = np.arange(N)
        np.random.shuffle(index)
        for i in range(batch_num):
            idx = index[i * self.half_batch_size:(i + 1) *
                        self.half_batch_size]
            s_concat = np.concatenate([states[idx], s_e[idx]], axis=0)
            a_concat = np.concatenate([actions[idx], a_e[idx]], axis=0)

            with tf.GradientTape(persistent=True) as tape:
                mu, std, sampled = self.discriminator.encode(
                    self.encoder, s_concat, a_concat)

                discs = self.discriminator.decode(
                    sampled if self.vail_sample else mu)
                # variational bottleneck: KL between the code distribution and the unit Gaussian
                kld_loss = tf.reduce_mean(tf_gaussian_KL(mu, 0, std, 1))
                # train D towards 1 on agent samples and towards 0 on expert samples
                agent_loss = -tf.reduce_mean(
                    tf.math.log(discs[:self.half_batch_size] + 1e-8))
                expert_loss = -tf.reduce_mean(
                    tf.math.log(1 + 1e-8 - discs[self.half_batch_size:]))
                disc_loss = agent_loss + expert_loss
                discriminator_loss = disc_loss + self.beta * kld_loss
            disc_grads = tape.gradient(discriminator_loss, self.disc_vars)
            if self.grad_global_norm > 0:
                disc_grads, _ = tf.clip_by_global_norm(disc_grads,
                                                       self.grad_global_norm)
            self.discriminator.opt.apply_gradients(
                zip(disc_grads, self.disc_vars))
            del tape

        # TODO: update posterior
        # L1 loss = logQ(code|s,prev_a,prev_code)
        # update actor & critic
        # batch_num = math.ceil(len(states) / self.batch_size)
        batch_num = len(gaes) // self.batch_size
        index = np.arange(len(gaes))
        for _ in range(self.epochs):
            np.random.shuffle(index)
            for i in range(batch_num):
                # if i == batch_num - 1:
                #     idx = index[i*self.batch_size : ]
                # else:
                idx = index[i * self.batch_size:(i + 1) * self.batch_size]
                state = states[idx]
                action = actions[idx]
                log_old_pi = log_old_pis[idx]
                gae = gaes[idx]
                oracle = oracles[idx]

                # update critic
                with tf.GradientTape(persistent=True) as tape:
                    values = self.critic(self.encoder, state)  # (N, 1)
                    critic_loss = tf.reduce_mean(
                        (oracle - values)**2)  # MSE loss
                critic_grads = tape.gradient(critic_loss, self.critic_vars)
                if self.grad_global_norm > 0:
                    critic_grads, _ = tf.clip_by_global_norm(
                        critic_grads, self.grad_global_norm)
                self.critic.opt.apply_gradients(
                    zip(critic_grads, self.critic_vars))
                del tape

                # update actor
                with tf.GradientTape(persistent=True) as tape:
                    pred_action = self.actor(self.encoder, state)

                    # RL (PPO) term
                    log_pi = tf.expand_dims(tf.math.log(
                        tf.reduce_sum(pred_action * action, axis=1) + 1e-8),
                                            axis=1)  # (N, 1)
                    ratio = tf.exp(log_pi - log_old_pi)
                    clip_ratio = tf.clip_by_value(ratio, 1 - self.clip,
                                                  1 + self.clip)
                    clip_loss = -tf.reduce_mean(
                        tf.minimum(ratio * gae, clip_ratio * gae))
                    # single-sample estimate of negative entropy; a positive
                    # coefficient therefore acts as an entropy bonus
                    entropy = tf.reduce_mean(tf.exp(log_pi) * log_pi)
                    actor_loss = clip_loss + self.entropy * entropy

                actor_grads = tape.gradient(
                    actor_loss, self.actor_vars)  # NOTE: freeze posterior
                if self.grad_global_norm > 0:
                    actor_grads, _ = tf.clip_by_global_norm(
                        actor_grads, self.grad_global_norm)
                self.actor.opt.apply_gradients(
                    zip(actor_grads, self.actor_vars))

                del tape
            # print('%d samples trained... D loss: %.4f C loss: %.4f A loss: %.4f\t\t\t'
            #     % (len(gaes), disc_loss, critic_loss, actor_loss), end='\r')

    def save_model(self, dir, tag=''):
        self.actor.save_weights(dir + tag + 'actor.h5')
        self.critic.save_weights(dir + tag + 'critic.h5')
        self.discriminator.save_weights(dir + tag + 'discriminator.h5')
        self.encoder.save_weights(dir + tag + 'encoder.h5')

    def load_model(self, dir, tag=''):
        if os.path.exists(dir + tag + 'actor.h5'):
            self.actor.load_weights(dir + tag + 'actor.h5')
            print('Actor loaded... %s%sactor.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'critic.h5'):
            self.critic.load_weights(dir + tag + 'critic.h5')
            print('Critic loaded... %s%scritic.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'discriminator.h5'):
            self.discriminator.load_weights(dir + tag + 'discriminator.h5')
            print('Discriminator loaded... %s%sdiscriminator.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'encoder.h5'):
            self.encoder.load_weights(dir + tag + 'encoder.h5')
            print('encoder loaded... %s%sencoder.h5' % (dir, tag))

    def load_encoder(self, dir, tag=''):
        if os.path.exists(dir + tag + 'encoder.h5'):
            self.encoder.load_weights(dir + tag + 'encoder.h5')
            print('encoder loaded... %s%sencoder.h5' % (dir, tag))
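
A minimal, hypothetical driving loop for the agent above is sketched below; the original rollout code is not shown. It assumes a Gym-style env whose observations match STATE_SHAPE, assumes HorizonMemory exposes an append(state, action_one_hot, log_old_pi, reward) method matching the order that rollout() returns inside memory_process, and uses placeholder hyperparameter values rather than the author's settings.

# Hypothetical usage sketch (assumptions noted above; not part of the original code)
import os
import numpy as np

demo_dir = 'demo/'                                       # assumed demonstration directory
agent = GAIL(vail_sample=True, reward_shift=False, reward_aug=True, gae_norm=True,
             global_norm=0.5, actor_lr=3e-4, critic_lr=1e-3, disc_lr=3e-4,
             actor_units=[256, 256], critic_units=[256, 256],
             disc_units=[256, 256], disc_reduce_units=[64],
             gamma=0.99, lambd=0.95, clip=0.2, entropy=0.01,
             epochs=3, batch_size=64, update_rate=2048,
             data_dir=demo_dir, demo_list=os.listdir(demo_dir))

for episode in range(1000):
    state = env.reset()[np.newaxis].astype(np.float32)   # env is assumed, obs shaped STATE_SHAPE
    done = False
    while not done:
        action, action_one_hot, log_old_pi, _ = agent.get_action(state)
        next_state, reward, done, _ = env.step(action)
        next_state = next_state[np.newaxis].astype(np.float32)
        agent.memory.append(state, action_one_hot, log_old_pi, reward)   # assumed API
        state = next_state
    # label the rollout with discriminator rewards, compute GAE/oracle targets,
    # and run a discriminator + PPO update once update_rate samples are stored
    agent.memory_process(state, done)
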
Example #2
            eval_net.eval()
            values = eval_net(*state)
            # print(values)
            # print(state)
            if random.random() > epsilon:
                # explore: sample uniformly over actions not masked out with -9999999
                prob_uniform = (values > -9999999).float()
                dist = Categorical(prob_uniform)
                action = dist.sample()
            else:
                # exploit: greedy action w.r.t. the current Q-values
                action = torch.argmax(values, dim=1)
            next_state, reward, done = envs.step(action.cpu().numpy())
            for idx in range(state[0].size(0)):
                memory.append([state[0][[idx]].cpu(),
                               state[1][[idx]].cpu(),
                               action[[idx]].cpu(),
                               reward[[idx]].cpu(),
                               next_state[0][[idx]].cpu(),
                               next_state[1][[idx]].cpu(),
                               done[[idx]].cpu()])
            state = next_state

            if len(memory) >= memory.capacity:
                learn(memory, eval_net, target_net, learn_step_counter, args.double_dqn)
                learn_step_counter += 1

            if _i % 50 == 0:
                ret = evaluate_batch(eval_net, 0, 100)
                ret = ret[0]
                performance.append(ret)
                time_now = time.time()
                print('average performance on {} {}{} {}: {:.4f}, time: {:.4f}, '
                      'step: {}'.format(args.fn, args.func_type, args.package_num,
                                        _i, performance[-1], time_now - time_last,
                                        learn_step_counter // q_network_iteration))
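
The learn() helper and evaluate_batch() are defined elsewhere and not shown. For reference, a minimal double-DQN update consistent with how learn() is invoked above might look like the sketch below; memory.sample(), the extra optimizer argument, GAMMA, BATCH_SIZE, and the use of q_network_iteration as the target-sync interval are all assumptions, not the original implementation.

# Hypothetical (double) DQN update sketch, not the original learn()
import torch
import torch.nn.functional as F

GAMMA = 0.99        # assumed discount factor
BATCH_SIZE = 64     # assumed batch size

def learn(memory, eval_net, target_net, learn_step_counter, double_dqn,
          optimizer, q_network_iteration=100):
    # periodically sync the target network (interval is an assumption)
    if learn_step_counter % q_network_iteration == 0:
        target_net.load_state_dict(eval_net.state_dict())

    # memory.sample is assumed to return batched tensors in the order they were appended
    s0, s1, action, reward, ns0, ns1, done = memory.sample(BATCH_SIZE)

    eval_net.train()
    q = eval_net(s0, s1).gather(1, action.long().view(-1, 1))          # Q(s, a)

    with torch.no_grad():
        if double_dqn:
            # double DQN: choose the next action with eval_net, evaluate it with target_net
            next_a = eval_net(ns0, ns1).argmax(dim=1, keepdim=True)
            next_q = target_net(ns0, ns1).gather(1, next_a)
        else:
            next_q = target_net(ns0, ns1).max(dim=1, keepdim=True)[0]
        target = reward.view(-1, 1) + GAMMA * next_q * (1.0 - done.float().view(-1, 1))

    loss = F.smooth_l1_loss(q, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()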