Example #1
import os

import numpy as np

# Actor, Critic, Discriminator, VAE_Encoder, DiscretePosterior, HorizonMemory and
# the STATE_SHAPE / ACTION_SHAPE / DISC_CODE_NUM / ACTION_NUM constants come from
# the project's own modules; their import paths are not shown in this snippet.


class GAIL:
    def __init__(self, reward_shift, actor_units, critic_units, disc_units,
                 disc_reduce_units, code_units):
        # build network
        self.actor = Actor(lr=0, hidden_units=actor_units)
        self.critic = Critic(lr=0, hidden_units=critic_units)
        self.discriminator = Discriminator(lr=0,
                                           hidden_units=disc_units,
                                           reduce_units=disc_reduce_units)
        self.encoder = VAE_Encoder(latent_num=64)
        self.prior = DiscretePosterior(lr=0, hidden_units=code_units)

        # set hyperparameters
        self.reward_shift = reward_shift
        self.memory = HorizonMemory()

        # ready
        self.dummy_forward()

    def dummy_forward(self):
        # connect networks
        dummy_state = np.zeros([1] + STATE_SHAPE, dtype=np.float32)
        dummy_action = np.zeros([1] + ACTION_SHAPE, dtype=np.float32)
        dummy_code = np.zeros([1, DISC_CODE_NUM], dtype=np.float32)
        self.encoder(dummy_state)
        self.prior(self.encoder, dummy_state, dummy_action, dummy_code)
        self.actor(self.encoder, dummy_state, dummy_code)
        self.critic(self.encoder, dummy_state)
        self.discriminator(self.encoder, dummy_state, dummy_action)

    def get_code(self, state, prev_action, prev_code):
        code_prob = self.prior(self.encoder, state, prev_action,
                               prev_code).numpy()[0]
        code_idx = np.argmax(code_prob)  # greedy
        # code_idx = np.random.choice(DISC_CODE_NUM, p=code_prob)
        code = np.eye(DISC_CODE_NUM, dtype=np.float32)[[code_idx]]  # (1, C)
        return code_idx, code, code_prob

    def get_action(self, state, code):
        policy = self.actor(self.encoder, state, code).numpy()[0]
        # action = np.random.choice(ACTION_NUM, p=policy)
        action = np.argmax(policy)  # greedy
        action_one_hot = np.eye(ACTION_NUM,
                                dtype=np.float32)[[action]]  # (1, 4)
        log_old_pi = [[np.log(policy[action] + 1e-8)]]  # (1, 1)
        return action, action_one_hot, log_old_pi, policy

    def get_reward(self, states, actions):
        d = self.discriminator(self.encoder, states, actions).numpy()  # (N, 1)
        # rewards = 0.5 - d       # linear reward
        # rewards = np.tan(0.5 - d)     # tan reward
        if self.reward_shift:
            rewards = -np.log(2.0 * d + 1e-8)  # log equil reward
        else:
            rewards = -np.log(d + 1e-8)  # log reward
        # rewards = 0.1 * np.where(rewards>1, 1, rewards)
        return rewards

    def load_model(self, dir, tag=''):
        if os.path.exists(dir + tag + 'actor.h5'):
            self.actor.load_weights(dir + tag + 'actor.h5')
            print('Actor loaded... %s%sactor.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'critic.h5'):
            self.critic.load_weights(dir + tag + 'critic.h5')
            print('Critic loaded... %s%scritic.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'discriminator.h5'):
            self.discriminator.load_weights(dir + tag + 'discriminator.h5')
            print('Discriminator loaded... %s%sdiscriminator.h5' % (dir, tag))

    def load_encoder(self, dir, tag=''):
        if os.path.exists(dir + tag + 'encoder.h5'):
            self.encoder.load_weights(dir + tag + 'encoder.h5')
            print('encoder loaded... %s%sencoder.h5' % (dir, tag))

    def load_prior(self, dir, tag=''):
        if os.path.exists(dir + tag + 'prior.h5'):
            self.prior.load_weights(dir + tag + 'prior.h5')
            print('prior loaded... %s%sprior.h5' % (dir, tag))
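
The class above is an inference-only wrapper (all learning rates are 0). A rollout loop along the following lines could drive it; this is a minimal sketch in which `make_env`, `preprocess`, the directory paths, and the hyperparameter values are illustrative assumptions, and the environment is assumed to expose a gym-style `reset()` / `step()` interface.

# Hypothetical rollout loop for the inference-only agent above.
agent = GAIL(reward_shift=False,
             actor_units=[256, 128], critic_units=[256, 128],
             disc_units=[256, 128], disc_reduce_units=[64],
             code_units=[128])
agent.load_model('./save/')
agent.load_encoder('./save/')
agent.load_prior('./save/')

env = make_env()                                    # assumed environment factory
state = preprocess(env.reset())                     # (1, *STATE_SHAPE) float32
prev_action = np.zeros([1] + ACTION_SHAPE, dtype=np.float32)
prev_code = np.eye(DISC_CODE_NUM, dtype=np.float32)[[0]]
done = False
while not done:
    _, code, _ = agent.get_code(state, prev_action, prev_code)
    action, action_one_hot, _, _ = agent.get_action(state, code)
    obs, env_reward, done, info = env.step(action)
    state, prev_action, prev_code = preprocess(obs), action_one_hot, code
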
Example #2
import math
import os
import random

import numpy as np
import tensorflow as tf

# Actor, DiscretePosterior, VAE_Encoder, tf_reparameterize and the
# STATE_SHAPE / ACTION_SHAPE / DISC_CODE_NUM constants come from the
# project's own modules; their import paths are not shown in this snippet.


class CodeVAE:
    def __init__(self, global_norm, lr, actor_units, code_units, epochs,
                 batch_size, data_dir, demo_list):
        # build network
        self.actor = Actor(lr=lr, hidden_units=actor_units)
        self.prior = DiscretePosterior(lr=lr, hidden_units=code_units)
        self.encoder = VAE_Encoder(latent_num=64)
        self.opt = tf.keras.optimizers.Adam(learning_rate=lr)

        # set hyperparameters
        self.epochs = epochs
        self.batch_size = batch_size
        self.grad_global_norm = global_norm
        self.init_temperature = 2.0
        self.temperature = self.init_temperature
        self.min_temperature = 0.5
        self.temp_decay = 1e-3
        self.beta = 1e-4

        # build expert demonstration Pipeline
        self.data_dir = data_dir
        self.demo_list = os.listdir(data_dir)
        self.demo_group_num = 500
        self.demo_rotate = 3
        assert len(demo_list) >= self.demo_group_num
        self.set_demo()
        self.total = 0
        # ready
        self.dummy_forward()
        self.vars = self.actor.trainable_variables + self.prior.trainable_variables

    def dummy_forward(self):
        # connect networks
        dummy_state = np.zeros([1] + STATE_SHAPE, dtype=np.float32)
        dummy_action = np.zeros([1] + ACTION_SHAPE, dtype=np.float32)
        dummy_code = np.zeros([1] + [DISC_CODE_NUM], dtype=np.float32)
        self.encoder(dummy_state)
        self.prior(self.encoder, dummy_state, dummy_action, dummy_code)
        self.actor(self.encoder, dummy_state, dummy_code)

    def set_demo(self):
        self.demo_list = os.listdir(self.data_dir)
        selected_demos = random.sample(self.demo_list, self.demo_group_num)

        self.expert_states = []
        self.expert_actions = []
        for demo_name in selected_demos:
            demo = np.load(self.data_dir + demo_name)
            states = demo['state']
            actions = demo['action']

            self.expert_states.append(states)
            self.expert_actions.append(actions)
        # self.expert_states = np.concatenate(expert_states, axis=0)
        # self.expert_actions = np.concatenate(expert_actions, axis=0)
        del demo

    def update_temperature(self, epoch):
        self.temperature = \
            max(self.min_temperature, self.init_temperature * math.exp(-self.temp_decay * epoch))

    def label_prev_code(self, s, a):
        # sequential labeling
        prev_codes = []
        running_code = np.eye(DISC_CODE_NUM, dtype=np.float32)[[
            np.random.randint(0, DISC_CODE_NUM)
        ]]  # initial code
        # c_0 ~ c_t-1   [N-1, C]
        for t in range(1, len(s)):
            # s_t a_t-1 c_t-1 -> c_t
            prev_codes.append(running_code)
            running_code = self.prior(self.encoder, s[t:t + 1], a[t - 1:t],
                                      running_code).numpy()
            running_code = np.eye(DISC_CODE_NUM, dtype=np.float32)[[
                np.random.choice(DISC_CODE_NUM, p=running_code[0])
            ]]

        return np.concatenate(prev_codes, axis=0)

    def update(self):
        # load expert demonstration
        # states, prev_actions = self.get_demonstration()
        # about 20000 samples
        states = []
        actions = []
        prev_actions = []
        prev_codes = []
        for s, a in zip(self.expert_states, self.expert_actions):
            states.append(s[1:])  # s_1: s_t
            prev_actions.append(a[:-1])  # a_0 : a_t-1
            actions.append(a[1:])  # a_1 : a_t
            prev_code = self.label_prev_code(s, a)  # c_0 : c_t-1
            prev_codes.append(prev_code)
        states = np.concatenate(states, axis=0)
        actions = np.concatenate(actions, axis=0)
        prev_actions = np.concatenate(prev_actions, axis=0)
        prev_codes = np.concatenate(prev_codes, axis=0)
        # print(prev_codes)
        batch_num = len(states) // self.batch_size
        index = np.arange(len(states))
        loss = 0
        for epoch in range(self.epochs):
            np.random.shuffle(index)
            for i in range(batch_num):
                idx = index[i * self.batch_size:(i + 1) * self.batch_size]
                state = states[idx]  # (N, S) s_t
                action = actions[idx]  # (N, A) a_t
                prev_action = prev_actions[idx]  # (N, A) a_t-1
                prev_code = prev_codes[idx]  # (N, C) c_t-1

                # update vae
                with tf.GradientTape() as tape:
                    code = self.prior(self.encoder, state, prev_action,
                                      prev_code)  # (N, C) c_t
                    sampled_code = tf_reparameterize(code, self.temperature)
                    policy = self.actor(self.encoder, state,
                                        sampled_code)  # (N, A) a_t
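                    # VAE-style objective: a cross-entropy reconstruction term
                    # on the decoded action plus a KL term pulling the relaxed
                    # code distribution toward a uniform prior over
                    # DISC_CODE_NUM codes, weighted by beta.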
                    log_probs = tf.math.log(sampled_code + 1e-8)
                    log_prior_probs = tf.math.log(1 / DISC_CODE_NUM)
                    kld_loss = tf.reduce_mean(
                        tf.reduce_sum(sampled_code *
                                      (log_probs - log_prior_probs),
                                      axis=1))
                    actor_loss = -tf.reduce_mean(
                        tf.reduce_sum(action * tf.math.log(policy + 1e-8),
                                      axis=1))  # (N-1, )

                    vae_loss = self.beta * kld_loss + actor_loss
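                    # debug printout of a single sample; note it indexes
                    # element 100, which assumes batch_size > 100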
                    print(
                        ('{:.2f} ' * 4).format(*policy.numpy()[100]) + ' / ' +
                        ('{:.2f} ' * 4).format(*code.numpy()[100]) + ' / ' +
                        ('{:.2f} ' * 4).format(*sampled_code.numpy()[100]),
                        '%.2f %.2f %.2f %.2f' %
                        (vae_loss.numpy(), kld_loss.numpy(),
                         actor_loss.numpy(), self.temperature),
                        end='\r')
                grads = tape.gradient(vae_loss, self.vars)
                if self.grad_global_norm > 0:
                    grads, _ = tf.clip_by_global_norm(grads,
                                                      self.grad_global_norm)
                self.opt.apply_gradients(zip(grads, self.vars))
                loss += vae_loss.numpy()
                self.total += 1
                self.update_temperature(self.total)

        loss /= self.epochs * batch_num
        return loss

    def save_model(self, dir, tag=''):
        self.actor.save_weights(dir + tag + 'actor.h5')
        self.prior.save_weights(dir + tag + 'prior.h5')

    def load_model(self, dir, tag=''):
        if os.path.exists(dir + tag + 'actor.h5'):
            self.actor.load_weights(dir + tag + 'actor.h5')
            print('Actor loaded... %s%sactor.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'prior.h5'):
            self.prior.load_weights(dir + tag + 'prior.h5')
            print('prior loaded... %s%sprior.h5' % (dir, tag))

    def load_encoder(self, dir, tag=''):
        if os.path.exists(dir + tag + 'encoder.h5'):
            self.encoder.load_weights(dir + tag + 'encoder.h5')
            print('Encoder loaded... %s%sencoder.h5' % (dir, tag))
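
CodeVAE.update() relies on a tf_reparameterize(probs, temperature) helper that is not part of this snippet. Assuming it is the usual Gumbel-softmax relaxation (which matches how the temperature is annealed above), it would look roughly like this:

def tf_reparameterize(probs, temperature, eps=1e-8):
    # Gumbel-softmax relaxation: a differentiable sample from a categorical
    # distribution given (N, C) class probabilities; it approaches a one-hot
    # vector as `temperature` goes to 0.  Reconstructed from the call site
    # above, not the project's actual implementation.
    uniform = tf.random.uniform(tf.shape(probs), minval=eps, maxval=1.0)
    gumbel = -tf.math.log(-tf.math.log(uniform))     # Gumbel(0, 1) noise
    logits = tf.math.log(probs + eps)
    return tf.nn.softmax((logits + gumbel) / temperature, axis=-1)
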
Example #3
import os
import random

import numpy as np
import tensorflow as tf

# Actor, Critic, Discriminator, VAE_Encoder, HorizonMemory, ReplayMemory,
# tf_gaussian_KL and the STATE_SHAPE / ACTION_SHAPE / ACTION_NUM / BETA_INIT
# constants come from the project's own modules; their import paths are not
# shown in this snippet.


class GAIL:
    def __init__(self, vail_sample, reward_shift, reward_aug, gae_norm,
                 global_norm, actor_lr, critic_lr, disc_lr, actor_units,
                 critic_units, disc_units, disc_reduce_units, gamma, lambd,
                 clip, entropy, epochs, batch_size, update_rate, data_dir,
                 demo_list):
        # build network
        self.actor = Actor(lr=actor_lr, hidden_units=actor_units)
        self.critic = Critic(lr=critic_lr, hidden_units=critic_units)
        self.discriminator = Discriminator(lr=disc_lr,
                                           hidden_units=disc_units,
                                           reduce_units=disc_reduce_units)
        self.encoder = VAE_Encoder(latent_num=64)

        # set hyperparameters
        self.vail_sample = vail_sample
        self.reward_shift = reward_shift
        self.reward_aug = reward_aug
        self.gae_norm = gae_norm
        self.gamma = gamma
        self.lambd = lambd
        self.gam_lam = gamma * lambd
        self.clip = clip
        self.entropy = entropy
        self.epochs = epochs
        self.batch_size = batch_size
        self.half_batch_size = batch_size // 2
        self.update_rate = update_rate
        self.grad_global_norm = global_norm
        self.beta = BETA_INIT

        # build memory
        self.memory = HorizonMemory(use_reward=reward_aug)
        self.replay = ReplayMemory()

        # build expert demonstration Pipeline
        self.data_dir = data_dir
        self.demo_list = os.listdir(data_dir)
        self.demo_group_num = 500
        self.demo_rotate = 5
        assert len(demo_list) >= self.demo_group_num
        self.set_demo()

        # ready
        self.dummy_forward()
        self.actor_vars = self.actor.trainable_variables + self.encoder.trainable_variables
        self.critic_vars = self.critic.trainable_variables + self.encoder.trainable_variables
        self.disc_vars = self.discriminator.trainable_variables + self.encoder.trainable_variables

    def dummy_forward(self):
        # connect networks
        dummy_state = np.zeros([1] + STATE_SHAPE, dtype=np.float32)
        dummy_action = np.zeros([1] + ACTION_SHAPE, dtype=np.float32)
        self.encoder(dummy_state)
        self.actor(self.encoder, dummy_state)
        self.critic(self.encoder, dummy_state)
        self.discriminator(self.encoder, dummy_state, dummy_action)

    def set_demo(self):
        self.demo_list = os.listdir(self.data_dir)
        selected_demos = random.sample(self.demo_list, self.demo_group_num)

        expert_states = []
        expert_actions = []
        for demo_name in selected_demos:
            demo = np.load(self.data_dir + demo_name)
            states = demo['state']
            actions = demo['action']

            expert_states.append(states)
            expert_actions.append(actions)
        self.expert_states = np.concatenate(expert_states, axis=0)
        self.expert_actions = np.concatenate(expert_actions, axis=0)
        del demo

    def get_demonstration(self, sample_num):
        # re-sample the demo pool if it cannot supply enough transitions
        if len(self.expert_states) < sample_num:
            self.set_demo()
        index = np.arange(len(self.expert_states))
        np.random.shuffle(index)
        index = index[:sample_num]
        return self.expert_states[index], self.expert_actions[index]

    def memory_process(self, next_state, done):
        # [[(1,64,64,3)], [], ...], [[(1,2),(1,9),(1,3),(1,4)], [], ...], [[c_pi, d_pi, s_pi, a_pi], [], ...]
        if self.reward_aug:
            states, actions, log_old_pis, rewards = self.memory.rollout()
        else:
            states, actions, log_old_pis = self.memory.rollout()
        np_states = np.concatenate(states + [next_state], axis=0)
        np_actions = np.concatenate(actions, axis=0)

        np_rewards = self.get_reward(np_states[:-1], np_actions)  # (N, 1)
        if self.reward_aug:
            np_env_rewards = np.stack(rewards, axis=0).reshape(-1, 1)
            np_rewards = np_rewards + np_env_rewards
        gae, oracle = self.get_gae_oracle(np_states, np_rewards,
                                          done)  # (N, 1), (N, 1)
        self.replay.append(states, actions, log_old_pis, gae, oracle)
        self.memory.flush()
        if len(self.replay) >= self.update_rate:
            self.update()
            self.replay.flush()

    def get_action(self, state):
        policy = self.actor(self.encoder, state).numpy()[0]
        action = np.random.choice(ACTION_NUM, p=policy)
        # action = np.argmax(policy)
        action_one_hot = np.eye(ACTION_NUM,
                                dtype=np.float32)[[action]]  # (1, 4)
        log_old_pi = [[np.log(policy[action] + 1e-8)]]  # (1, 1)
        return action, action_one_hot, log_old_pi, policy

    def get_reward(self, states, actions):
        d = self.discriminator(self.encoder, states, actions).numpy()  # (N, 1)
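        # D is trained (in update) toward 1 on agent pairs and 0 on expert
        # pairs, so a small d means "expert-like"; -log(d) therefore rewards
        # imitation, and the shifted variant subtracts log(2) so that the
        # decision boundary d = 0.5 maps to zero reward.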
        # rewards = 0.5 - d       # linear reward
        # rewards = np.tan(0.5 - d)     # tan reward
        if self.reward_shift:
            rewards = -np.log(2.0 * d + 1e-8)  # log equil reward
        else:
            rewards = -np.log(d + 1e-8)  # log reward
        # rewards = 0.1 * np.where(rewards>1, 1, rewards)
        return rewards

    def get_gae_oracle(self, states, rewards, done):
        # states include next state
        values = self.critic(self.encoder, states).numpy()  # (N+1, 1)
        if done:
            values[-1] = np.float32([0])
        N = len(rewards)
        gae = 0
        gaes = np.zeros((N, 1), dtype=np.float32)
        oracles = np.zeros((N, 1), dtype=np.float32)
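        # backward recursion over the trajectory:
        #   oracle_t = r_t + gamma * V(s_{t+1})             (1-step TD target)
        #   delta_t  = oracle_t - V(s_t)                    (TD error)
        #   A_t      = delta_t + gamma * lambda * A_{t+1}   (GAE)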
        for t in reversed(range(N)):
            oracles[t] = rewards[t] + self.gamma * values[t + 1]
            delta = oracles[t] - values[t]
            gae = delta + self.gam_lam * gae
            gaes[t][0] = gae

        # oracles = gaes + values[:-1]        # (N, 1)
        if self.gae_norm:
            gaes = (gaes - np.mean(gaes)) / (np.std(gaes) + 1e-8)
        return gaes, oracles

    def update(self):
        # load & calculate data
        states, actions, log_old_pis, gaes, oracles \
            = self.replay.rollout()

        states = np.concatenate(states, axis=0)
        actions = np.concatenate(actions, axis=0)
        log_old_pis = np.concatenate(log_old_pis, axis=0)
        gaes = np.concatenate(gaes, axis=0)
        oracles = np.concatenate(oracles, axis=0)
        N = len(states)
        # update discriminator
        # load expert demonstration
        s_e, a_e = self.get_demonstration(N)

        batch_num = N // self.half_batch_size
        index = np.arange(N)
        np.random.shuffle(index)
        for i in range(batch_num):
            idx = index[i * self.half_batch_size:(i + 1) *
                        self.half_batch_size]
            s_concat = np.concatenate([states[idx], s_e[idx]], axis=0)
            a_concat = np.concatenate([actions[idx], a_e[idx]], axis=0)

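            # VAIL-style discriminator update: the first half of each batch is
            # agent data, the second half expert data (see s_concat above), and
            # beta weights a KL bottleneck on the discriminator's latent
            # encoding against the usual GAIL classification loss.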
            with tf.GradientTape(persistent=True) as tape:
                mu, std, sampled = self.discriminator.encode(
                    self.encoder, s_concat, a_concat)

                discs = self.discriminator.decode(
                    sampled if self.vail_sample else mu)
                kld_loss = tf.reduce_mean(tf_gaussian_KL(mu, 0, std, 1))
                agent_loss = -tf.reduce_mean(
                    tf.math.log(discs[:self.half_batch_size] + 1e-8))
                expert_loss = -tf.reduce_mean(
                    tf.math.log(1 + 1e-8 - discs[self.half_batch_size:]))
                disc_loss = agent_loss + expert_loss
                discriminator_loss = disc_loss + self.beta * kld_loss
            disc_grads = tape.gradient(discriminator_loss, self.disc_vars)
            if self.grad_global_norm > 0:
                disc_grads, _ = tf.clip_by_global_norm(disc_grads,
                                                       self.grad_global_norm)
            self.discriminator.opt.apply_gradients(
                zip(disc_grads, self.disc_vars))
            del tape

        # TODO: update posterior
        # L1 loss = logQ(code|s,prev_a,prev_code)
        # update actor & critic
        # batch_num = math.ceil(len(states) / self.batch_size)
        batch_num = len(gaes) // self.batch_size
        index = np.arange(len(gaes))
        for _ in range(self.epochs):
            np.random.shuffle(index)
            for i in range(batch_num):
                # if i == batch_num - 1:
                #     idx = index[i*self.batch_size : ]
                # else:
                idx = index[i * self.batch_size:(i + 1) * self.batch_size]
                state = states[idx]
                action = actions[idx]
                log_old_pi = log_old_pis[idx]
                gae = gaes[idx]
                oracle = oracles[idx]

                # update critic
                with tf.GradientTape(persistent=True) as tape:
                    values = self.critic(self.encoder, state)  # (N, 1)
                    critic_loss = tf.reduce_mean(
                        (oracle - values)**2)  # MSE loss
                critic_grads = tape.gradient(critic_loss, self.critic_vars)
                if self.grad_global_norm > 0:
                    critic_grads, _ = tf.clip_by_global_norm(
                        critic_grads, self.grad_global_norm)
                self.critic.opt.apply_gradients(
                    zip(critic_grads, self.critic_vars))
                del tape

                # update actor
                with tf.GradientTape(persistent=True) as tape:
                    pred_action = self.actor(self.encoder, state)

                    # RL (PPO) term
                    log_pi = tf.expand_dims(tf.math.log(
                        tf.reduce_sum(pred_action * action, axis=1) + 1e-8),
                                            axis=1)  # (N, 1)
                    ratio = tf.exp(log_pi - log_old_pi)
                    clip_ratio = tf.clip_by_value(ratio, 1 - self.clip,
                                                  1 + self.clip)
                    clip_loss = -tf.reduce_mean(
                        tf.minimum(ratio * gae, clip_ratio * gae))
                    entropy = tf.reduce_mean(tf.exp(log_pi) * log_pi)
                    actor_loss = clip_loss + self.entropy * entropy

                actor_grads = tape.gradient(
                    actor_loss, self.actor_vars)  # NOTE: freeze posterior
                if self.grad_global_norm > 0:
                    actor_grads, _ = tf.clip_by_global_norm(
                        actor_grads, self.grad_global_norm)
                self.actor.opt.apply_gradients(
                    zip(actor_grads, self.actor_vars))

                del tape
            # print('%d samples trained... D loss: %.4f C loss: %.4f A loss: %.4f\t\t\t'
            #     % (len(gaes), disc_loss, critic_loss, actor_loss), end='\r')

    def save_model(self, dir, tag=''):
        self.actor.save_weights(dir + tag + 'actor.h5')
        self.critic.save_weights(dir + tag + 'critic.h5')
        self.discriminator.save_weights(dir + tag + 'discriminator.h5')
        self.encoder.save_weights(dir + tag + 'encoder.h5')

    def load_model(self, dir, tag=''):
        if os.path.exists(dir + tag + 'actor.h5'):
            self.actor.load_weights(dir + tag + 'actor.h5')
            print('Actor loaded... %s%sactor.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'critic.h5'):
            self.critic.load_weights(dir + tag + 'critic.h5')
            print('Critic loaded... %s%scritic.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'discriminator.h5'):
            self.discriminator.load_weights(dir + tag + 'discriminator.h5')
            print('Discriminator loaded... %s%sdiscriminator.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'encoder.h5'):
            self.encoder.load_weights(dir + tag + 'encoder.h5')
            print('encoder loaded... %s%sencoder.h5' % (dir, tag))

    def load_encoder(self, dir, tag=''):
        if os.path.exists(dir + tag + 'encoder.h5'):
            self.encoder.load_weights(dir + tag + 'encoder.h5')
            print('encoder loaded... %s%sencoder.h5' % (dir, tag))
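
The discriminator update above calls tf_gaussian_KL(mu, 0, std, 1), another helper that is not included in this snippet. Assuming the argument order is (mu1, mu2, std1, std2), a standard element-wise Gaussian KL divergence fits the call site:

def tf_gaussian_KL(mu1, mu2, std1, std2, eps=1e-8):
    # Element-wise KL( N(mu1, std1) || N(mu2, std2) ); with mu2=0, std2=1 this
    # is the VAIL information-bottleneck penalty used in update() above.
    # A reconstruction under the assumed argument order, not the original code.
    mu2 = tf.cast(mu2, tf.float32)       # allow plain Python scalars
    std2 = tf.cast(std2, tf.float32)
    return (tf.math.log(std2 / (std1 + eps))
            + (tf.square(std1) + tf.square(mu1 - mu2)) / (2.0 * tf.square(std2))
            - 0.5)
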
Example #4
import os
import random

import numpy as np
import tensorflow as tf

# Actor, VAE_Encoder and the STATE_SHAPE constant come from the project's own
# modules; their import paths are not shown in this snippet.


class BC:
    def __init__(self, global_norm, actor_lr, actor_units,
                 epochs, batch_size, data_dir, demo_list):
        # build network
        self.actor = Actor(lr=actor_lr, hidden_units=actor_units)
        self.encoder = VAE_Encoder(latent_num=64)
        self.opt = tf.keras.optimizers.Adam(learning_rate=actor_lr)

        # set hyperparameters
        self.epochs = epochs
        self.batch_size = batch_size
        self.grad_global_norm = global_norm

        # build expert demonstration Pipeline
        self.data_dir = data_dir
        self.demo_list = os.listdir(data_dir)
        self.demo_group_num = 1000
        self.demo_rotate = 20
        assert len(demo_list) >= self.demo_group_num
        self.set_demo()

        # ready
        self.dummy_forward()
        self.vars = self.actor.trainable_variables + self.encoder.trainable_variables

    def dummy_forward(self):
        # connect networks
        dummy_state = np.zeros([1] + STATE_SHAPE, dtype=np.float32)
        self.encoder(dummy_state)
        self.actor(self.encoder, dummy_state)

    def set_demo(self):
        self.demo_list = os.listdir(self.data_dir)
        selected_demos = random.sample(self.demo_list, self.demo_group_num)

        expert_states = []
        expert_actions = []
        for demo_name in selected_demos:
            demo = np.load(self.data_dir + demo_name)
            states = demo['state']
            actions = demo['action']
    
            expert_states.append(states)
            expert_actions.append(actions)
        self.expert_states = np.concatenate(expert_states, axis=0)
        self.expert_actions = np.concatenate(expert_actions, axis=0)
        del demo

    def get_demonstration(self, sample_num):
        # re-sample the demo pool if it cannot supply enough transitions
        if len(self.expert_states) < sample_num:
            self.set_demo()
        index = np.arange(len(self.expert_states))
        np.random.shuffle(index)
        index = index[:sample_num]
        return self.expert_states[index], self.expert_actions[index]

    def update(self):
        # load expert demonstration
        s_e, a_e = self.get_demonstration(self.batch_size * self.epochs)
        
        batch_num = len(s_e) // self.batch_size
        index = np.arange(len(s_e))
        np.random.shuffle(index)
        loss = 0
        for i in range(batch_num):
            idx = index[i*self.batch_size : (i+1)*self.batch_size]
            state = s_e[idx]
            action = a_e[idx]

            # update actor
            with tf.GradientTape() as tape:
                pred_action = self.actor(self.encoder, state)   # (N, A)
                # CE
                actor_loss = -tf.reduce_mean(
                    tf.reduce_sum(action * tf.math.log(pred_action + 1e-8),
                                  axis=1))
            grads = tape.gradient(actor_loss, self.vars)
            if self.grad_global_norm > 0:
                grads, _ = tf.clip_by_global_norm(grads, self.grad_global_norm)
            self.opt.apply_gradients(zip(grads, self.vars))
            loss += actor_loss.numpy()
        return loss / batch_num

    def save_model(self, dir, tag=''):
        self.actor.save_weights(dir + tag + 'actor.h5')
        self.encoder.save_weights(dir + tag + 'encoder.h5')

    def load_model(self, dir, tag=''):
        if os.path.exists(dir + tag + 'actor.h5'):
            self.actor.load_weights(dir + tag + 'actor.h5')
            print('Actor loaded... %s%sactor.h5' % (dir, tag))
        if os.path.exists(dir + tag + 'encoder.h5'):
            self.encoder.load_weights(dir + tag + 'encoder.h5')
            print('encoder loaded... %s%sencoder.h5' % (dir, tag))
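
A sketch of how this BC trainer might be driven from an outer loop; the paths, hyperparameter values, and rotation schedule are illustrative assumptions (demo_rotate is stored by the class, but only an outer loop like this would use it).

# Hypothetical outer training loop for the BC agent above.
data_dir = './expert_demo/'
agent = BC(global_norm=0.5, actor_lr=1e-4, actor_units=[256, 128],
           epochs=4, batch_size=128,
           data_dir=data_dir, demo_list=os.listdir(data_dir))
agent.load_model('./save/')
for it in range(1, 1001):
    loss = agent.update()
    if it % agent.demo_rotate == 0:      # periodically re-sample the demo pool
        agent.set_demo()
    if it % 50 == 0:
        agent.save_model('./save/')
        print('iter %d | BC loss %.4f' % (it, loss))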