Example 1
class HIRO:
    def __init__(self,
                 env,
                 gamma=0.99,
                 polyak=0.995,
                 c=10,
                 d=2,
                 high_act_noise=0.1,
                 low_act_noise=0.1,
                 high_rew_scale=0.1,
                 low_rew_scale=1.0,
                 render=False,
                 batch_size=32,
                 q_lr=1e-3,
                 p_lr=1e-4,
                 buffer_capacity=5000,
                 max_episodes=100,
                 save_path=None,
                 load_path=None,
                 print_freq=1,
                 log_dir='logs/train',
                 training=True
                 ):
        self.gamma = gamma
        self.polyak = polyak
        self.low_act_noise = low_act_noise
        self.high_act_noise = high_act_noise
        self.low_rew_scale = low_rew_scale
        self.high_rew_scale = high_rew_scale
        self.render = render
        self.batch_size = batch_size
        self.p_lr = p_lr
        self.q_lr = q_lr
        self.max_episodes = max_episodes
        self.env = env
        self.rewards = []
        self.print_freq = print_freq
        self.save_path = save_path
        self.c = c
        self.d = d
        self.higher_buffer = ReplayBuffer(buffer_capacity, tuple_length=5)
        self.lower_buffer = ReplayBuffer(buffer_capacity, tuple_length=4)

        self.low_actor, self.low_critic_1, self.low_critic_2 = create_actor_critic(
            state_dim=2 * env.observation_space.shape[0],
            action_dim=env.action_space.shape[0],
            action_range=env.action_space.high)

        self.low_target_actor, self.low_target_critic_1, self.low_target_critic_2 = create_actor_critic(
            state_dim=2 * env.observation_space.shape[0],
            action_dim=env.action_space.shape[0],
            action_range=env.action_space.high)

        self.high_actor, self.high_critic_1, self.high_critic_2 = create_actor_critic(
            state_dim=env.observation_space.shape[0],
            action_dim=env.observation_space.shape[0],
            action_range=env.observation_space.high)

        self.high_target_actor, self.high_target_critic_1, self.high_target_critic_2 = create_actor_critic(
            state_dim=env.observation_space.shape[0],
            action_dim=env.observation_space.shape[0],
            action_range=env.observation_space.high)
        self.low_target_actor.set_weights(self.low_actor.get_weights())
        self.low_target_critic_1.set_weights(self.low_critic_1.get_weights())
        self.low_target_critic_2.set_weights(self.low_critic_2.get_weights())
        self.high_target_actor.set_weights(self.high_actor.get_weights())
        self.high_target_critic_1.set_weights(self.high_critic_1.get_weights())
        self.high_target_critic_2.set_weights(self.high_critic_2.get_weights())

        if training:
            self.low_actor_optimizer = tf.keras.optimizers.Adam(learning_rate=self.p_lr)
            self.low_critic_1_optimizer = tf.keras.optimizers.Adam(learning_rate=self.q_lr)
            self.low_critic_2_optimizer = tf.keras.optimizers.Adam(learning_rate=self.q_lr)
            self.high_actor_optimizer = tf.keras.optimizers.Adam(learning_rate=self.p_lr)
            self.high_critic_1_optimizer = tf.keras.optimizers.Adam(learning_rate=self.q_lr)
            self.high_critic_2_optimizer = tf.keras.optimizers.Adam(learning_rate=self.q_lr)
            self.mse = tf.keras.losses.MeanSquaredError()
            self.summary_writer = tf.summary.create_file_writer(log_dir)

            self.low_actor_train_fn = self.create_train_step_actor_fn(self.low_actor, self.low_critic_1,
                                                                      self.low_actor_optimizer)
            self.low_critic_train_fns = [self.create_train_step_critic_fn(critic=c, optimizer=o) for c, o in
                                         [(self.low_critic_1, self.low_critic_1_optimizer),
                                          (self.low_critic_2, self.low_critic_2_optimizer)]]

            self.high_actor_train_fn = self.create_train_step_actor_fn(self.high_actor, self.high_critic_1,
                                                                       self.high_actor_optimizer)
            self.high_critic_train_fns = [self.create_train_step_critic_fn(critic=c, optimizer=o) for c, o in
                                          [(self.high_critic_1, self.high_critic_1_optimizer),
                                           (self.high_critic_2, self.high_critic_2_optimizer)]]
        if load_path is not None:
            self.low_actor.load_weights(f'{load_path}/low/actor')
            self.low_critic_1.load_weights(f'{load_path}/low/critic_1')
            self.low_critic_2.load_weights(f'{load_path}/low/critic_2')
            self.high_actor.load_weights(f'{load_path}/high/actor')
            self.high_critic_1.load_weights(f'{load_path}/high/critic_1')
            self.high_critic_2.load_weights(f'{load_path}/high/critic_2')

    @staticmethod
    def goal_transition(state, goal, next_state):
        return state + goal - next_state

    @staticmethod
    def intrinsic_reward(state, goal, next_state):
        return - np.linalg.norm(state + goal - next_state)

    def act(self, obs, goal, noise=False):
        norm_dist = tf.random.normal(self.env.action_space.shape, stddev=self.low_act_noise * self.env.action_space.high)
        action = self.low_actor(np.concatenate((obs, goal), axis=1)).numpy()
        action = np.clip(action + (norm_dist.numpy() if noise else 0),
                         a_min=self.env.action_space.low,
                         a_max=self.env.action_space.high)
        return action

    def get_goal(self, obs, noise=False):
        norm_dist = tf.random.normal(self.env.observation_space.shape, stddev=self.high_act_noise * self.env.observation_space.high)
        action = self.high_actor(obs).numpy()
        action = np.clip(action + (norm_dist.numpy() if noise else 0),
                         a_min=self.env.observation_space.low,
                         a_max=self.env.observation_space.high)
        return action

    @tf.function
    def log_probability(self, states, actions, candidate_goal):
        goals = tf.reshape(candidate_goal, (1, -1))

        def body(curr_i, curr_goals, s):
            new_goals = tf.concat(
                (curr_goals,
                 tf.reshape(self.goal_transition(s[curr_i - 1], curr_goals[curr_i - 1], s[curr_i]), (1, -1))), axis=0)
            curr_i += 1
            return [curr_i, new_goals, s]

        def condition(curr_i, curr_goals, s):
            return curr_i < s.shape[0] and not (
                    tf.equal(tf.math.count_nonzero(s[curr_i]), 0) and tf.equal(tf.math.count_nonzero(actions[curr_i]),
                                                                               0))

        # If a state-action pair is all zero, then the episode ended before an entire sequence of length c was recorded.
        # We must remove these empty states and actions from the log probability calculation, as they could skew the
        #   argmax computation
        i = tf.constant(1)
        i, goals, states = tf.while_loop(condition, body, [i, goals, states],
                                         shape_invariants=[tf.TensorShape(None), tf.TensorShape([None, goals.shape[1]]),
                                                           states.shape])
        states = states[:i, :]
        actions = actions[:i, :]

        action_predictions = self.low_actor(tf.concat((states, goals), axis=1))
        return -(1 / 2) * tf.reduce_sum(tf.linalg.norm(actions - action_predictions, axis=1))

    @tf.function
    def off_policy_correct(self, states, goals, actions, new_states):
        first_states = tf.reshape(states, (self.batch_size, -1))[:, :new_states[0].shape[0]]
        means = new_states - first_states
        std_dev = 0.5 * (1 / 2) * tf.convert_to_tensor(self.env.observation_space.high)

        for i in range(states.shape[0]):
            # Sample eight candidate goals sampled randomly from a Gaussian centered at s_{t+c} - s_t
            # Include the original goal and a goal corresponding to the difference s_{t+c} - s_t
            # TODO: clip the random actions to lie within the high-level action range
            candidate_goals = tf.concat(
                (tf.random.normal(shape=(8, self.env.observation_space.shape[0]), mean=means[i], stddev=std_dev),
                 tf.reshape(goals[i], (1, -1)), tf.reshape(means[i], (1, -1))),
                axis=0)

            chosen_goal = tf.argmax(
                [self.log_probability(states[i], actions[i], candidate_goals[g]) for g in
                 range(candidate_goals.shape[0])])
            goals = tf.tensor_scatter_nd_update(goals, [[i]], [candidate_goals[chosen_goal]])

        return first_states, goals

    @tf.function
    def train_step_critics(self, states, actions, rewards, next_states, actor, target_critic_1,
                           target_critic_2, critic_train_fns, target_noise,
                           scope='Policy'):
        target_goal_preds = actor(next_states)
        target_goal_preds += target_noise

        target_q_values_1 = target_critic_1([next_states, target_goal_preds])
        target_q_values_2 = target_critic_2([next_states, target_goal_preds])

        target_q_values = tf.concat((target_q_values_1, target_q_values_2), axis=1)
        target_q_values = tf.reshape(tf.reduce_min(target_q_values, axis=1), (self.batch_size, -1))
        targets = rewards + self.gamma * target_q_values

        critic_train_fns[0](states, actions, targets, scope=scope, label='Critic 1')
        critic_train_fns[1](states, actions, targets, scope=scope, label='Critic 2')

    def create_train_step_actor_fn(self, actor, critic, optimizer):
        @tf.function
        def train_step_actor(states, scope='policy', label='actor'):
            with tf.GradientTape() as tape:
                action_predictions = actor(states)
                q_values = critic([states, action_predictions])
                policy_loss = -tf.reduce_mean(q_values)
            gradients = tape.gradient(policy_loss, actor.trainable_variables)
            optimizer.apply_gradients(zip(gradients, actor.trainable_variables))

            with tf.name_scope(scope):
                with self.summary_writer.as_default():
                    tf.summary.scalar(f'{label} Policy Loss', policy_loss, step=optimizer.iterations)

        return train_step_actor

    def create_train_step_critic_fn(self, critic, optimizer):
        @tf.function
        def train_step_critic(states, actions, targets, scope='Policy', label='Critic'):
            with tf.GradientTape() as tape:
                q_values = critic([states, actions])
                mse_loss = self.mse(q_values, targets)
            gradients = tape.gradient(mse_loss, critic.trainable_variables)
            optimizer.apply_gradients(zip(gradients, critic.trainable_variables))

            with tf.name_scope(scope):
                with self.summary_writer.as_default():
                    tf.summary.scalar(f'{label} MSE Loss', mse_loss, step=optimizer.iterations)
                    tf.summary.scalar(f'{label} Mean Q Values', tf.reduce_mean(q_values), step=optimizer.iterations)

        return train_step_critic

    def update_lower(self):
        if len(self.lower_buffer) >= self.batch_size:
            states, actions, rewards, next_states = self.lower_buffer.sample(self.batch_size)
            rewards = rewards.reshape(-1, 1).astype(np.float32)

            self.train_step_critics(states, actions, rewards, next_states, self.low_actor, self.low_target_critic_1,
                                    self.low_target_critic_2,
                                    self.low_critic_train_fns,
                                    target_noise=tf.random.normal(actions.shape,
                                                                  stddev=0.1 * self.env.action_space.high),
                                    scope='Lower_Policy')

            if self.low_critic_1_optimizer.iterations % self.d == 0:
                self.low_actor_train_fn(states, scope='Lower_Policy', label='Actor')

                # Update target networks
                polyak_average(self.low_actor.variables, self.low_target_actor.variables, self.polyak)
                polyak_average(self.low_critic_1.variables, self.low_target_critic_1.variables, self.polyak)
                polyak_average(self.low_critic_2.variables, self.low_target_critic_2.variables, self.polyak)

    def update_higher(self):
        if len(self.higher_buffer) >= self.batch_size:
            states, goals, actions, rewards, next_states = self.higher_buffer.sample(self.batch_size)
            rewards = rewards.reshape((-1, 1))

            states, goals, actions, rewards, next_states = (tf.convert_to_tensor(states, dtype=tf.float32),
                                                            tf.convert_to_tensor(goals, dtype=tf.float32),
                                                            tf.convert_to_tensor(actions, dtype=tf.float32),
                                                            tf.convert_to_tensor(rewards, dtype=tf.float32),
                                                            tf.convert_to_tensor(next_states, dtype=tf.float32))

            states, goals = self.off_policy_correct(states=states, goals=goals, actions=actions, new_states=next_states)

            self.train_step_critics(states, goals, rewards, next_states, self.high_actor, self.high_target_critic_1,
                                    self.high_target_critic_2,
                                    self.high_critic_train_fns,
                                    target_noise=tf.random.normal(next_states.shape,
                                                                  stddev=0.1 * self.env.observation_space.high),
                                    scope='Higher_Policy')

            if self.high_critic_1_optimizer.iterations % self.d == 0:
                self.high_actor_train_fn(states, scope='Higher_Policy', label='Actor')

                # Update target networks
                polyak_average(self.high_actor.variables, self.high_target_actor.variables, self.polyak)
                polyak_average(self.high_critic_1.variables, self.high_target_critic_1.variables, self.polyak)
                polyak_average(self.high_critic_2.variables, self.high_target_critic_2.variables, self.polyak)

    def learn(self):
        # Collect experiences s_t, g_t, a_t, R_t
        mean_reward = None
        total_steps = 0

        for ep in range(self.max_episodes):
            if ep % self.print_freq == 0 and ep > 0:
                new_mean_reward = np.mean(self.rewards[-self.print_freq - 1:])

                print(f"-------------------------------------------------------")
                print(f"Mean {self.print_freq} Episode Reward: {new_mean_reward}")
                print(f"Total Episodes: {ep}")
                print(f"Total Steps: {total_steps}")
                print(f"-------------------------------------------------------")

                total_steps = 0
                with tf.name_scope('Episodic Information'):
                    with self.summary_writer.as_default():
                        tf.summary.scalar(f'Mean {self.print_freq} Episode Reward', new_mean_reward,
                                          step=ep // self.print_freq)

                # Model saving inspired by Open AI Baseline implementation
                if (mean_reward is None or new_mean_reward >= mean_reward) and self.save_path is not None:
                    print(f"Saving model due to mean reward increase:{mean_reward} -> {new_mean_reward}")
                    print(f'Location: {self.save_path}')
                    mean_reward = new_mean_reward

                    self.low_actor.save_weights(f'{self.save_path}/low/actor')
                    self.low_critic_1.save_weights(f'{self.save_path}/low/critic_1')
                    self.low_critic_2.save_weights(f'{self.save_path}/low/critic_2')
                    self.high_actor.save_weights(f'{self.save_path}/high/actor')
                    self.high_critic_1.save_weights(f'{self.save_path}/high/critic_1')
                    self.high_critic_2.save_weights(f'{self.save_path}/high/critic_2')

            obs = self.env.reset()
            goal = self.get_goal(obs.reshape((1, -1)), noise=True).flatten()
            higher_goal = goal
            higher_obs = []
            higher_actions = []
            higher_reward = 0
            episode_reward = 0
            episode_intrinsic_rewards = 0
            ep_len = 0
            c = 0

            done = False
            while not done:
                if self.render:
                    self.env.render()
                action = self.act(obs.reshape((1, -1)), goal.reshape((1, -1)), noise=True).flatten()
                new_obs, rew, done, info = self.env.step(action)
                new_obs = new_obs.flatten()
                new_goal = self.goal_transition(obs, goal, new_obs)
                episode_reward += rew

                # Goals are treated as additional state information for the low level
                # policy. Store transitions in respective replay buffers
                intrinsic_reward = self.intrinsic_reward(obs, goal, new_obs) * self.low_rew_scale
                self.lower_buffer.add((np.concatenate((obs, goal)), action,
                                       intrinsic_reward,
                                       np.concatenate((new_obs, new_goal)),))
                episode_intrinsic_rewards += intrinsic_reward

                self.update_lower()

                # Fill lists for single higher level transition
                higher_obs.append(obs)
                higher_actions.append(action)
                higher_reward += self.high_rew_scale * rew

                # Only add transitions to the high level replay buffer every c steps
                c += 1
                if c == self.c or done:
                    # Need all higher level transitions to be the same length
                    # fill the rest of this transition with zeros
                    while c < self.c:
                        higher_obs.append(np.full(self.env.observation_space.shape, 0))
                        higher_actions.append(np.full(self.env.action_space.shape, 0))
                        c += 1
                    self.higher_buffer.add((higher_obs, higher_goal, higher_actions, higher_reward, new_obs))

                    self.update_higher()
                    c = 0
                    higher_obs = []
                    higher_actions = []
                    higher_reward = 0
                    goal = self.get_goal(new_obs.reshape((1, -1)), noise=True).flatten()
                    higher_goal = goal

                obs = new_obs
                goal = new_goal
                ep_len += 1

            with tf.name_scope('Episodic Information'):
                with self.summary_writer.as_default():
                    tf.summary.scalar(f'Episode Environment Reward', episode_reward, step=ep)
                    tf.summary.scalar(f'Episode Intrinsic Reward', episode_intrinsic_rewards, step=ep)

            self.rewards.append(episode_reward)
            total_steps += ep_len
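
The HIRO class above relies on two helpers that are not shown: polyak_average for the soft target-network updates and ReplayBuffer for experience storage. The following is a minimal sketch of what they could look like, inferred purely from how they are called here; the real implementations may differ.

import random
from collections import deque

import numpy as np


def polyak_average(variables, target_variables, polyak):
    # Soft update of TensorFlow variables: target <- polyak * target + (1 - polyak) * online
    for var, target_var in zip(variables, target_variables):
        target_var.assign(polyak * target_var + (1 - polyak) * var)


class ReplayBuffer:
    def __init__(self, capacity, tuple_length=5):
        self.buffer = deque(maxlen=capacity)
        self.tuple_length = tuple_length

    def __len__(self):
        return len(self.buffer)

    def add(self, transition):
        assert len(transition) == self.tuple_length
        self.buffer.append(transition)

    def sample(self, batch_size):
        # Returns one array per transition field (states, actions, rewards, ...)
        batch = random.sample(self.buffer, batch_size)
        return tuple(np.array(field, dtype=np.float32) for field in zip(*batch))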
Example 2
        loss_G.backward()

        optimizer_G.step()

        ### Discriminator
        optimizer_D_A.zero_grad()
        optimizer_D_B.zero_grad()

        # real loss
        pred_real_A = netD_A(real_A)
        pred_real_B = netD_B(real_B)
        real_loss_A = torch.nn.functional.mse_loss(pred_real_A, target_real)
        real_loss_B = torch.nn.functional.mse_loss(pred_real_B, target_real)

        # fake loss
        fake_A = buffer_A.sample()
        fake_B = buffer_B.sample()
        pred_fake_A = netD_A(fake_A)
        pred_fake_B = netD_B(fake_B)
        fake_loss_A = torch.nn.functional.mse_loss(pred_fake_A, target_fake)
        fake_loss_B = torch.nn.functional.mse_loss(pred_fake_B, target_fake)

        # total loss
        loss_D_A = real_loss_A + fake_loss_A
        loss_D_B = real_loss_B + fake_loss_B
        loss_D_A.backward()
        loss_D_B.backward()

        optimizer_D_A.step()
        optimizer_D_B.step()
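
The discriminator update above samples previously generated fakes from buffer_A and buffer_B. A common choice (as in the CycleGAN paper) is a pool of past generator outputs; the sketch below is one hypothetical implementation, and the push() method is an assumption, since the excerpt does not show how fakes enter the pool.

import random

import torch


class ImageBuffer:
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.images = []

    def push(self, batch):
        # Store detached copies so discriminator updates do not backprop into the generators
        for image in batch.detach():
            if len(self.images) < self.max_size:
                self.images.append(image.clone())
            else:
                self.images[random.randrange(self.max_size)] = image.clone()

    def sample(self, batch_size=1):
        # Return a random batch of stored fakes
        return torch.stack(random.sample(self.images, min(batch_size, len(self.images))))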
Example 3
class SelfPlayTrainer:
    def __init__(self,
                 agent,
                 game,
                 buffer_file=None,
                 weights_file=None,
                 n_batches=0):
        self.agent = agent
        self.game = game
        self.replay_buffer = ReplayBuffer()

        if buffer_file is not None:
            self.replay_buffer.buffer = pickle.load(open(buffer_file, "rb"))

        self.current_network = NestedTTTNet()
        self.control_network = NestedTTTNet()

        if weights_file is not None:
            self.control_network.load_state_dict(torch.load(weights_file))

        self.current_network.load_state_dict(self.control_network.state_dict())
        self.control_network.eval()
        self.current_network.train()

        self.agent.update_control_net(self.control_network)

        self.n_batches = n_batches

        self.optim = torch.optim.Adam(self.current_network.parameters(),
                                      lr=.01,
                                      weight_decay=10e-4)

    def generate_self_play_data(self, n_games=100):
        for _ in range(n_games):
            turn_num = 0
            self.game.reset()
            self.agent.reset()
            result = 0
            player_num = 0

            states = []
            move_vectors = []

            while len(self.game.get_valid_moves()) > 0:
                move, move_probs = self.agent.search(self.game.copy(),
                                                     turn_num,
                                                     allotted_playouts=400)

                states.append(self.game.state.tolist())
                move_vectors.append(move_probs)

                result = self.game.make_move(move)
                if not result:
                    self.game.switch_player()
                    self.agent.take_action(move)
                    turn_num += 1
                    player_num = (player_num + 1) % 2

            if not result:
                self.replay_buffer.extend(
                    list(zip(states, move_vectors, zero_gen())))
            else:
                self.replay_buffer.extend(
                    list(
                        zip(states[::-1], move_vectors[::-1],
                            one_neg_one_gen()))[::-1])

    def compare_control_to_train(self):
        self.current_network.eval()
        old_agent = AlphaMCTSAgent(control_net=self.control_network)
        new_agent = AlphaMCTSAgent(control_net=self.current_network)

        agents = [old_agent, new_agent]

        wins = 0
        ties = 0

        game = self.game.copy()

        for game_num in range(100):
            game.reset()
            agents[0].reset()
            agents[1].reset()
            result = 0
            player_num = game_num // 50  #Both take first turn 50 times
            turn_num = 100  #Turn down the temperature

            while len(game.get_valid_moves()) > 0:
                move, _ = agents[player_num].search(game.copy(),
                                                    turn_num,
                                                    allotted_playouts=800)
                _, _ = agents[1 - player_num].search(game.copy(),
                                                     turn_num,
                                                     allotted_playouts=800)

                result = game.make_move(move)
                if not result:
                    game.switch_player()
                    agents[0].take_action(move)
                    agents[1].take_action(move)
                    player_num = (player_num + 1) % 2

            if not result:
                ties += 1
            elif result and player_num == 1:
                wins += 1

            print("After {} games, {} wins and {} ties".format(
                game_num + 1, wins, ties))

        if wins + .5 * ties >= 55:
            print(
                "Challenger network won {} games and tied {} games; it becomes new control network"
                .format(wins, ties))
            torch.save(self.current_network.state_dict(),
                       "control_weights_{}.pth".format(self.n_batches))
            self.control_network.load_state_dict(
                self.current_network.state_dict())
        else:
            print(
                "Challenger network not sufficiently better; {} wins and {} ties"
                .format(wins, ties))

        self.control_network.eval()
        self.current_network.train()

    def train_on_batch(self, batch_size=32):
        if len(self.replay_buffer) < batch_size:
            return

        self.current_network.train()

        sample = self.replay_buffer.sample(batch_size)
        states, probs, rewards = zip(*sample)
        states = torch.FloatTensor(states)
        probs = torch.FloatTensor(probs)
        rewards = torch.FloatTensor(rewards).unsqueeze(1)
        self.optim.zero_grad()

        ps, vs = self.current_network(states)

        loss = torch.nn.functional.mse_loss(
            vs, rewards) - (ps.log() * probs).sum()
        loss.backward()

        self.optim.step()

        self.n_batches += 1

        return loss.item()

    def run(self,
            total_runs=10,
            self_play_games=100,
            training_batches=200,
            batch_size=32):
        losses = []
        for run_num in range(1, total_runs + 1):
            print("Run {} of {}".format(run_num, total_runs))
            for selfplay_num in range(1, self_play_games + 1):
                self.generate_self_play_data(1)
                print("\tFinished self-play game {} of {} (Buffer size {})".
                      format(selfplay_num, self_play_games,
                             len(self.replay_buffer)))
            print("Finished {} self-play games".format(self_play_games))
            for _ in range(training_batches):
                losses.append(self.train_on_batch(batch_size))
                if len(losses) == 5:
                    print("\tLoss for last 5 batches: {}".format(sum(losses)))
                    losses = []

        self.compare_control_to_train()
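
generate_self_play_data above labels drawn games with zero_gen() and decided games with one_neg_one_gen(). These generators are not defined in the excerpt; a plausible sketch, inferred from how they are zipped against the (reversed) state list, is:

def zero_gen():
    # Every position of a drawn game gets value 0
    while True:
        yield 0


def one_neg_one_gen():
    # Alternate +1 / -1; zipped against the reversed game, so positions of the
    # player who made the final (winning) move get +1 and the opponent's get -1
    while True:
        yield 1
        yield -1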
Example 4
class DDPG:
    def __init__(
        self,
        env,
        gamma=0.99,
        polyak=0.995,
        act_noise=0.1,
        render=False,
        batch_size=32,
        q_lr=1e-3,
        p_lr=1e-4,
        buffer_capacity=5000,
        max_episodes=100,
        save_path=None,
        load_path=None,
        print_freq=1,
        start_steps=10000,
        log_dir='logs/train',
        training=True,
    ):
        self.gamma = gamma
        self.polyak = polyak
        self.act_noise = act_noise
        self.render = render
        self.batch_size = batch_size
        self.p_lr = p_lr
        self.q_lr = q_lr
        self.max_episodes = max_episodes
        self.start_steps = start_steps
        self.actor, self.critic = create_actor_critic(
            env.observation_space.shape[0], env.action_space.shape[0],
            env.action_space.high)
        self.target_actor, self.target_critic = create_actor_critic(
            env.observation_space.shape[0], env.action_space.shape[0],
            env.action_space.high)
        self.target_actor.set_weights(self.actor.get_weights())
        self.target_critic.set_weights(self.critic.get_weights())
        self.env = env
        self.rewards = []
        self.print_freq = print_freq
        self.save_path = save_path

        if training:
            self.buffer = ReplayBuffer(buffer_capacity)
            self.actor_optimizer = tf.keras.optimizers.Adam(
                learning_rate=self.p_lr)
            self.critic_optimizer = tf.keras.optimizers.Adam(
                learning_rate=self.q_lr)
            self.summary_writer = tf.summary.create_file_writer(log_dir)
            self.mse = tf.keras.losses.MeanSquaredError()
        if load_path is not None:
            self.actor.load_weights(f'{load_path}/actor')
            self.critic.load_weights(f'{load_path}/critic')

    @tf.function
    def train_step(self, states, actions, targets):
        with tf.GradientTape() as tape:
            action_predictions = self.actor(states)
            q_values = self.critic([states, action_predictions])
            policy_loss = -tf.reduce_mean(q_values)
        actor_gradients = tape.gradient(policy_loss,
                                        self.actor.trainable_variables)
        self.actor_optimizer.apply_gradients(
            zip(actor_gradients, self.actor.trainable_variables))

        with tf.GradientTape() as tape:
            q_values = self.critic([states, actions])
            mse_loss = self.mse(q_values, targets)
        critic_gradients = tape.gradient(mse_loss,
                                         self.critic.trainable_variables)
        self.critic_optimizer.apply_gradients(
            zip(critic_gradients, self.critic.trainable_variables))

        with self.summary_writer.as_default():
            tf.summary.scalar('Policy Loss',
                              policy_loss,
                              step=self.critic_optimizer.iterations)
            tf.summary.scalar('MSE Loss',
                              mse_loss,
                              step=self.critic_optimizer.iterations)
            tf.summary.scalar('Estimated Q Value',
                              tf.reduce_mean(q_values),
                              step=self.critic_optimizer.iterations)

    def update(self):
        if len(self.buffer) >= self.batch_size:
            # Sample random minibatch of N transitions
            states, actions, rewards, next_states, dones = self.buffer.sample(
                self.batch_size)
            dones = dones.reshape(-1, 1)
            rewards = rewards.reshape(-1, 1)

            # Set the target for learning
            target_action_preds = self.target_actor(next_states)
            target_q_values = self.target_critic(
                [next_states, target_action_preds])
            targets = rewards + self.gamma * target_q_values * (1 - dones)

            # update critic by minimizing the MSE loss
            # update the actor policy using the sampled policy gradient
            self.train_step(states, actions, targets)

            # Update target networks
            polyak_average(self.actor.variables, self.target_actor.variables,
                           self.polyak)
            polyak_average(self.critic.variables, self.target_critic.variables,
                           self.polyak)

    def act(self, obs, noise=False):
        # Initialize a random process N for action exploration
        norm_dist = tf.random.normal(self.env.action_space.shape,
                                     stddev=self.act_noise)

        action = self.actor(np.expand_dims(obs, axis=0))
        action = np.clip(action.numpy() + (norm_dist.numpy() if noise else 0),
                         a_min=self.env.action_space.low,
                         a_max=self.env.action_space.high)
        return action

    def learn(self):
        mean_reward = None
        total_steps = 0
        overall_steps = 0
        for ep in range(self.max_episodes):
            if ep % self.print_freq == 0 and ep > 0:
                new_mean_reward = np.mean(self.rewards[-self.print_freq - 1:])

                print(
                    f"-------------------------------------------------------")
                print(
                    f"Mean {self.print_freq} Episode Reward: {new_mean_reward}"
                )
                print(f"Mean Steps: {total_steps / self.print_freq}")
                print(f"Total Episodes: {ep}")
                print(f"Total Steps: {overall_steps}")
                print(
                    f"-------------------------------------------------------")

                total_steps = 0
                with self.summary_writer.as_default():
                    tf.summary.scalar(f'Mean {self.print_freq} Episode Reward',
                                      new_mean_reward,
                                      step=ep)

                # Model saving inspired by Open AI Baseline implementation
                if (mean_reward is None or new_mean_reward >= mean_reward
                    ) and self.save_path is not None:
                    print(
                        f"Saving model due to mean reward increase: {mean_reward} -> {new_mean_reward}"
                    )
                    print(f'Location: {self.save_path}')
                    mean_reward = new_mean_reward

                    self.actor.save_weights(f'{self.save_path}/actor')
                    self.critic.save_weights(f'{self.save_path}/critic')

            # Receive initial observation state s_1
            obs = self.env.reset()
            done = False
            episode_reward = 0
            ep_len = 0
            while not done:
                # Display the environment
                if self.render:
                    self.env.render()

                # Execute action and observe reward and observe new state
                if self.start_steps > 0:
                    self.start_steps -= 1
                    action = self.env.action_space.sample()
                else:
                    # Select action according to policy and exploration noise
                    action = self.act(obs, noise=True).flatten()
                new_obs, rew, done, info = self.env.step(action)
                new_obs = new_obs.flatten()
                episode_reward += rew

                # Store transition in R
                self.buffer.add((obs, action, rew, new_obs, done))

                # Perform a single learning step
                self.update()

                obs = new_obs
                ep_len += 1

            with self.summary_writer.as_default():
                tf.summary.scalar(f'Episode Reward', episode_reward, step=ep)

            self.rewards.append(episode_reward)
            total_steps += ep_len
            overall_steps += ep_len
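
A hedged usage sketch of the DDPG class: it expects a Gym-style continuous-control environment with the old API (reset() returning an observation, step() returning a 4-tuple). The environment name and paths below are only illustrative choices, not taken from the original code.

import gym

env = gym.make('Pendulum-v0')
agent = DDPG(env, max_episodes=200, save_path='models/ddpg', log_dir='logs/ddpg')
agent.learn()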
Example 5
class DDPGAgent():
    
    def __init__(self, state_size, action_size, num_agents):
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(RANDOM_SEED)
        self.num_agents = num_agents

        # Actor Network (w/ Target Network)
        self.actor_local = Actor(state_size, action_size).to(device)
        self.actor_target = Actor(state_size, action_size).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=LR_ACTOR)

        # Critic Network (w/ Target Network)
        self.critic_local = Critic(state_size, action_size).to(device)
        self.critic_target = Critic(state_size, action_size).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(), lr=LR_CRITIC, weight_decay=WEIGHT_DECAY)

        # Noise process
        self.noise = OUNoise(action_size)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE)
        
        # Directory where to save the model
        self.model_dir = os.getcwd() + "/DDPG/saved_models"
        os.makedirs(self.model_dir, exist_ok=True)

    def step(self, states, actions, rewards, next_states, dones):
        for i in range(self.num_agents):
            self.memory.add(states[i], actions[i], rewards[i], next_states[i], dones[i])

        if len(self.memory) > BATCH_SIZE:
            experiences = self.memory.sample()
            self.learn(experiences, GAMMA)
        
    def act(self, states, add_noise=True):
        states = torch.from_numpy(states).float().to(device)

        actions = np.zeros((self.num_agents, self.action_size))
        self.actor_local.eval()
        with torch.no_grad():
            for i, state in enumerate(states):
                actions[i, :] = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()
        
        if add_noise:
            actions += self.noise.sample()
        
        return np.clip(actions, -1, 1)

    def reset(self):
        self.noise.reset()

    def learn(self, experiences, gamma):
        states, actions, rewards, next_states, dones = experiences

        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)  # adds gradient clipping to stabilize learning
        self.critic_optimizer.step()
        
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.soft_update(self.critic_local, self.critic_target, TAU)
        self.soft_update(self.actor_local, self.actor_target, TAU)
        
    def soft_update(self, local_model, target_model, tau):
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
            
    def save_model(self):
        torch.save(
            self.actor_local.state_dict(), 
            os.path.join(self.model_dir, 'actor_params.pth')
        )
        torch.save(
            self.actor_optimizer.state_dict(), 
            os.path.join(self.model_dir, 'actor_optim_params.pth')
        )
        torch.save(
            self.critic_local.state_dict(), 
            os.path.join(self.model_dir, 'critic_params.pth')
        )
        torch.save(
            self.critic_optimizer.state_dict(), 
            os.path.join(self.model_dir, 'critic_optim_params.pth')
        )

    def load_model(self):
        """Loads weights from saved model."""
        self.actor_local.load_state_dict(
            torch.load(os.path.join(self.model_dir, 'actor_params.pth'))
        )
        self.actor_optimizer.load_state_dict(
            torch.load(os.path.join(self.model_dir, 'actor_optim_params.pth'))
        )
        self.critic_local.load_state_dict(
            torch.load(os.path.join(self.model_dir, 'critic_params.pth'))
        )
        self.critic_optimizer.load_state_dict(
            torch.load(os.path.join(self.model_dir, 'critic_optim_params.pth'))
        )
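
DDPGAgent uses an OUNoise process that is not shown. A typical Ornstein-Uhlenbeck implementation, with hyperparameters that are common defaults rather than values taken from this code, looks like:

import copy

import numpy as np


class OUNoise:
    def __init__(self, size, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        self.reset()

    def reset(self):
        # Re-center the process on the mean
        self.state = copy.copy(self.mu)

    def sample(self):
        # dx = theta * (mu - x) + sigma * N(0, 1)
        dx = self.theta * (self.mu - self.state) + self.sigma * np.random.standard_normal(len(self.state))
        self.state = self.state + dx
        return self.state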
Example 6
class Agent(BaseAgent):
    def __init__(self, env, **kwargs):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.obs_space = env.observation_space
        self.action_space = env.action_space
        super(Agent, self).__init__(env.action_space)
        mask = kwargs.get('mask', 2)
        mask_hi = kwargs.get('mask_hi', 19)
        self.rule = kwargs.get('rule', 'c')
        self.danger = kwargs.get('danger', 0.9)
        self.bus_thres = kwargs.get('threshold', 0.1)
        self.max_low_len = kwargs.get('max_low_len', 19)
        self.converter = graphGoalConverter(env, mask, mask_hi, self.danger,
                                            self.device, self.rule)
        self.thermal_limit = env._thermal_limit_a
        self.convert_obs = self.converter.convert_obs
        self.action_dim = self.converter.n
        self.order_dim = len(self.converter.masked_sorted_sub)
        self.node_num = env.dim_topo
        self.delay_step = 2
        self.update_step = 0
        self.k_step = 1
        self.nheads = kwargs.get('head_number', 8)
        self.target_update = kwargs.get('target_update', 1)
        self.hard_target = kwargs.get('hard_target', False)
        self.use_order = (self.rule == 'o')

        self.gamma = kwargs.get('gamma', 0.99)
        self.tau = kwargs.get('tau', 1e-3)
        self.dropout = kwargs.get('dropout', 0.)
        self.memlen = kwargs.get('memlen', int(1e5))
        self.batch_size = kwargs.get('batch_size', 128)
        self.update_start = self.batch_size * 8
        self.actor_lr = kwargs.get('actor_lr', 5e-5)
        self.critic_lr = kwargs.get('critic_lr', 5e-5)
        self.embed_lr = kwargs.get('embed_lr', 5e-5)
        self.alpha_lr = kwargs.get('alpha_lr', 5e-5)

        self.state_dim = kwargs.get('state_dim', 128)
        self.n_history = kwargs.get('n_history', 6)
        self.input_dim = self.converter.n_feature * self.n_history

        print(
            f'N: {self.node_num}, O: {self.input_dim}, S: {self.state_dim}, A: {self.action_dim}, ({self.order_dim})'
        )
        print(kwargs)
        self.emb = EncoderLayer(self.input_dim, self.state_dim, self.nheads,
                                self.node_num, self.dropout).to(self.device)
        self.temb = EncoderLayer(self.input_dim, self.state_dim, self.nheads,
                                 self.node_num, self.dropout).to(self.device)
        self.Q = DoubleSoftQ(self.state_dim, self.nheads, self.node_num,
                             self.action_dim, self.use_order, self.order_dim,
                             self.dropout).to(self.device)
        self.tQ = DoubleSoftQ(self.state_dim, self.nheads, self.node_num,
                              self.action_dim, self.use_order, self.order_dim,
                              self.dropout).to(self.device)
        self.actor = Actor(self.state_dim, self.nheads, self.node_num,
                           self.action_dim, self.use_order, self.order_dim,
                           self.dropout).to(self.device)

        # copy parameters
        self.tQ.load_state_dict(self.Q.state_dict())
        self.temb.load_state_dict(self.emb.state_dict())

        # entropy
        self.target_entropy = -self.action_dim * 3 if not self.use_order else -3 * (
            self.action_dim + self.order_dim)
        self.log_alpha = torch.FloatTensor([-3]).to(self.device)
        self.log_alpha.requires_grad = True

        # optimizers
        self.Q.optimizer = optim.Adam(self.Q.parameters(), lr=self.critic_lr)
        self.actor.optimizer = optim.Adam(self.actor.parameters(),
                                          lr=self.actor_lr)
        self.emb.optimizer = optim.Adam(self.emb.parameters(),
                                        lr=self.embed_lr)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=self.alpha_lr)

        self.memory = ReplayBuffer(max_size=self.memlen)
        self.Q.eval()
        self.tQ.eval()
        self.emb.eval()
        self.temb.eval()
        self.actor.eval()

    def is_safe(self, obs):
        for ratio, limit in zip(obs.rho, self.thermal_limit):
            # Separate large lines from small lines
            if (limit < 400.00
                    and ratio >= self.danger - 0.05) or ratio >= self.danger:
                return False
        return True

    def load_mean_std(self, mean, std):
        self.state_mean = mean
        self.state_std = std.masked_fill(std < 1e-5, 1.)
        self.state_mean[0, sum(self.obs_space.shape[:20]):] = 0
        self.state_std[0, sum(self.obs_space.shape[:20]):] = 1

    def state_normalize(self, s):
        s = (s - self.state_mean) / self.state_std
        return s

    def reset(self, obs):
        self.converter.last_topo = np.ones(self.node_num, dtype=int)
        self.topo = None
        self.goal = None
        self.goal_list = []
        self.low_len = -1
        self.adj = None
        self.stacked_obs = []
        self.low_actions = []
        self.save = False

    def cache_stat(self):
        cache = {
            'last_topo': self.converter.last_topo,
            'topo': self.topo,
            'goal': self.goal,
            'goal_list': self.goal_list,
            'low_len': self.low_len,
            'adj': self.adj,
            'stacked_obs': self.stacked_obs,
            'low_actions': self.low_actions,
            'save': self.save,
        }
        return cache

    def load_cache_stat(self, cache):
        self.converter.last_topo = cache['last_topo']
        self.topo = cache['topo']
        self.goal = cache['goal']
        self.goal_list = cache['goal_list']
        self.low_len = cache['low_len']
        self.adj = cache['adj']
        self.stacked_obs = cache['stacked_obs']
        self.low_actions = cache['low_actions']
        self.save = cache['save']

    def hash_goal(self, goal):
        hashed = ''
        for i in goal.view(-1):
            hashed += str(int(i.item()))
        return hashed

    def stack_obs(self, obs):
        obs_vect = obs.to_vect()
        obs_vect = torch.FloatTensor(obs_vect).unsqueeze(0)
        obs_vect, self.topo = self.convert_obs(self.state_normalize(obs_vect))
        if len(self.stacked_obs) == 0:
            for _ in range(self.n_history):
                self.stacked_obs.append(obs_vect)
        else:
            self.stacked_obs.pop(0)
            self.stacked_obs.append(obs_vect)
        self.adj = (torch.FloatTensor(obs.connectivity_matrix()) +
                    torch.eye(int(obs.dim_topo))).to(self.device)
        self.converter.last_topo = np.where(obs.topo_vect == -1,
                                            self.converter.last_topo,
                                            obs.topo_vect)

    def reconnect_line(self, obs):
        # If the agent can reconnect a powerline that is not attached to a controllable substation, return that action;
        # otherwise, return None
        dislines = np.where(obs.line_status == False)[0]
        for i in dislines:
            act = None
            if obs.time_next_maintenance[
                    i] != 0 and i in self.converter.lonely_lines:
                sub_or = self.action_space.line_or_to_subid[i]
                sub_ex = self.action_space.line_ex_to_subid[i]
                if obs.time_before_cooldown_sub[sub_or] == 0:
                    act = self.action_space(
                        {'set_bus': {
                            'lines_or_id': [(i, 1)]
                        }})
                if obs.time_before_cooldown_sub[sub_ex] == 0:
                    act = self.action_space(
                        {'set_bus': {
                            'lines_ex_id': [(i, 1)]
                        }})
                if obs.time_before_cooldown_line[i] == 0:
                    status = self.action_space.get_change_line_status_vect()
                    status[i] = True
                    act = self.action_space({'change_line_status': status})
                if act is not None:
                    return act
        return None

    def get_current_state(self):
        return torch.cat(self.stacked_obs + [self.topo], dim=-1)

    def act(self, obs, reward, done):
        sample = (reward is None)
        self.stack_obs(obs)
        is_safe = self.is_safe(obs)
        self.save = False

        # reconnect a disconnected powerline attached to an uncontrollable substation
        if False in obs.line_status:
            act = self.reconnect_line(obs)
            if act is not None:
                return act

        # generate goal if it is initial or previous goal has been reached
        if self.goal is None or (not is_safe and self.low_len == -1):
            goal, bus_goal, low_actions, order, Q1, Q2 = self.generate_goal(
                sample, obs, not sample)
            if len(low_actions) == 0:
                act = self.action_space()
                if self.goal is None:
                    self.update_goal(goal, bus_goal, low_actions, order, Q1,
                                     Q2)
                return self.action_space()
            self.update_goal(goal, bus_goal, low_actions, order, Q1, Q2)

        act = self.pick_low_action(obs)
        return act

    def pick_low_action(self, obs):
        # Safe and there are no queued low-level actions: do nothing
        if self.is_safe(obs) and self.low_len == -1:
            act = self.action_space()
            return act

        # optimize low actions every step
        self.low_actions = self.optimize_low_actions(obs, self.low_actions)
        self.low_len += 1

        # queue is empty after optimization: do nothing
        if len(self.low_actions) == 0:
            act = self.action_space()
            self.low_len = -1

        # normally execute low action from low actions queue
        else:
            sub_id, new_topo = self.low_actions.pop(0)[:2]
            act = self.converter.convert_act(sub_id, new_topo, obs.topo_vect)

        # When the maximum number of low-level steps is reached, reset
        if self.max_low_len <= self.low_len:
            self.low_len = -1
        return act

    def high_act(self, stacked_state, adj, sample=True):
        order, Q1, Q2 = None, 0, 0
        with torch.no_grad():
            # stacked_state # B, N, F
            stacked_t, stacked_x = stacked_state[...,
                                                 -1:], stacked_state[..., :-1]
            emb_input = stacked_x
            state = self.emb(emb_input, adj).detach()
            actor_input = [state, stacked_t.squeeze(-1)]
            if sample:
                action, std = self.actor.sample(actor_input, adj)
                if self.use_order:
                    action, order = action
                critic_input = action
                Q1, Q2 = self.Q(state, critic_input, adj, order)
                Q1, Q2 = Q1.detach()[0].item(), Q2.detach()[0].item()
                if self.use_order:
                    std, order_std = std
            else:
                action = self.actor.mean(actor_input, adj)
                if self.use_order:
                    action, order = action
        if order is not None: order = order.detach().cpu()
        return action.detach().cpu(), order, Q1, Q2

    def make_candidate_goal(self, stacked_state, adj, sample, obs):
        goal, order, Q1, Q2 = self.high_act(stacked_state, adj, sample)
        bus_goal = torch.zeros_like(goal).long()
        bus_goal[goal > self.bus_thres] = 1
        low_actions = self.converter.plan_act(
            bus_goal, obs.topo_vect, order[0] if order is not None else None)
        low_actions = self.optimize_low_actions(obs, low_actions)
        return goal, bus_goal, low_actions, order, Q1, Q2

    def generate_goal(self, sample, obs, nosave=False):
        stacked_state = self.get_current_state().to(self.device)
        adj = self.adj.unsqueeze(0)
        goal, bus_goal, low_actions, order, Q1, Q2 = self.make_candidate_goal(
            stacked_state, adj, sample, obs)
        return goal, bus_goal, low_actions, order, Q1, Q2

    def update_goal(self, goal, bus_goal, low_actions, order=None, Q1=0, Q2=0):
        self.order = order
        self.goal = goal
        self.bus_goal = bus_goal
        self.low_actions = low_actions
        self.low_len = 0
        self.save = True
        self.goal_list.append(self.hash_goal(bus_goal))

    def optimize_low_actions(self, obs, low_actions):
        # remove overlapping actions
        optimized = []
        cooldown_list = obs.time_before_cooldown_sub
        if self.max_low_len != 1 and self.rule == 'c':
            low_actions = self.converter.heuristic_order(obs, low_actions)
        for low_act in low_actions:
            sub_id, sub_goal = low_act[:2]
            sub_goal, same = self.converter.inspect_act(
                sub_id, sub_goal, obs.topo_vect)
            if not same:
                optimized.append((sub_id, sub_goal, cooldown_list[sub_id]))

        # sort by cooldown_sub
        if self.max_low_len != 1 and self.rule != 'o':
            optimized = sorted(optimized, key=lambda x: x[2])

        # if current action has cooldown, then discard
        if len(optimized) > 0 and optimized[0][2] > 0:
            optimized = []
        return optimized

    def append_sample(self, s, m, a, r, s2, m2, d, order):
        if self.use_order:
            self.memory.append((s, m, a, r, s2, m2, int(d), order))
        else:
            self.memory.append((s, m, a, r, s2, m2, int(d)))

    def unpack_batch(self, batch):
        if self.use_order:
            states, adj, actions, rewards, states2, adj2, dones, orders = list(
                zip(*batch))
            orders = torch.cat(orders, 0)
        else:
            states, adj, actions, rewards, states2, adj2, dones = list(
                zip(*batch))
        states = torch.cat(states, 0)
        states2 = torch.cat(states2, 0)
        adj = torch.stack(adj, 0)
        adj2 = torch.stack(adj2, 0)
        actions = torch.cat(actions, 0)
        rewards = torch.FloatTensor(rewards).unsqueeze(1)
        dones = torch.FloatTensor(dones).unsqueeze(1)
        if self.use_order:
            return states.to(self.device), adj.to(self.device), actions.to(self.device), rewards.to(self.device), \
                states2.to(self.device), adj2.to(self.device), dones.to(self.device), orders.to(self.device)
        else:
            return states.to(self.device), adj.to(self.device), actions.to(self.device), \
                rewards.to(self.device), states2.to(self.device), adj2.to(self.device), dones.to(self.device)

    def update(self):
        self.update_step += 1
        batch = self.memory.sample(self.batch_size)
        orders = None
        if self.use_order:
            stacked_states, adj, actions, rewards, stacked_states2, adj2, dones, orders = self.unpack_batch(
                batch)
        else:
            stacked_states, adj, actions, rewards, stacked_states2, adj2, dones = self.unpack_batch(
                batch)

        self.Q.train()
        self.emb.train()
        self.actor.eval()

        # critic loss
        stacked_t, stacked_x = stacked_states[...,
                                              -1:], stacked_states[..., :-1]
        stacked2_t, stacked2_x = stacked_states2[..., -1:], stacked_states2[
            ..., :-1]
        emb_input = stacked_x
        emb_input2 = stacked2_x
        states = self.emb(emb_input, adj)
        states2 = self.emb(emb_input2, adj2)
        actor_input2 = [states2, stacked2_t.squeeze(-1)]
        with torch.no_grad():
            tstates2 = self.temb(emb_input2, adj2).detach()
            action2, log_pi2 = self.actor.rsample(actor_input2, adj2)
            order2 = None
            if self.use_order:
                action2, order2 = action2
                log_pi2 = log_pi2[0] + log_pi2[1]
            critic_input2 = action2
            targets = self.tQ.min_Q(tstates2, critic_input2, adj2,
                                    order2) - self.log_alpha.exp() * log_pi2

        targets = rewards + (1 - dones) * self.gamma * targets.detach()

        critic_input = actions
        predQ1, predQ2 = self.Q(states, critic_input, adj, orders)

        Q1_loss = F.mse_loss(predQ1, targets)
        Q2_loss = F.mse_loss(predQ2, targets)

        loss = Q1_loss + Q2_loss
        self.Q.optimizer.zero_grad()
        self.emb.optimizer.zero_grad()
        loss.backward()
        self.emb.optimizer.step()
        self.Q.optimizer.step()

        self.Q.eval()

        if self.update_step % self.delay_step == 0:
            # actor loss
            self.actor.train()
            states = self.emb(emb_input, adj)
            actor_input = [states, stacked_t.squeeze(-1)]
            action, log_pi = self.actor.rsample(actor_input, adj)
            order = None
            if self.use_order:
                action, order = action
                log_pi = log_pi[0] + log_pi[1]
            critic_input = action
            actor_loss = (
                self.log_alpha.exp() * log_pi -
                self.Q.min_Q(states, critic_input, adj, order)).mean()

            self.emb.optimizer.zero_grad()
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.emb.optimizer.step()
            self.actor.optimizer.step()

            self.actor.eval()

            # target update
            if self.hard_target:
                self.tQ.load_state_dict(self.Q.state_dict())
                self.temb.load_state_dict(self.emb.state_dict())
            else:
                for tp, p in zip(self.tQ.parameters(), self.Q.parameters()):
                    tp.data.copy_(self.tau * p + (1 - self.tau) * tp)
                for tp, p in zip(self.temb.parameters(),
                                 self.emb.parameters()):
                    tp.data.copy_(self.tau * p + (1 - self.tau) * tp)

            # alpha loss
            alpha_loss = self.log_alpha * (-log_pi.detach() -
                                           self.target_entropy).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
        self.emb.eval()

        return predQ1.detach().mean().item(), predQ2.detach().mean().item()

    def save_model(self, path, name):
        torch.save(self.actor.state_dict(),
                   os.path.join(path, f'{name}_actor.pt'))
        torch.save(self.emb.state_dict(), os.path.join(path, f'{name}_emb.pt'))
        torch.save(self.Q.state_dict(), os.path.join(path, f'{name}_Q.pt'))

    def load_model(self, path, name=None):
        head = ''
        if name is not None:
            head = name + '_'
        self.actor.load_state_dict(
            torch.load(os.path.join(path, f'{head}actor.pt'),
                       map_location=self.device))
        self.emb.load_state_dict(
            torch.load(os.path.join(path, f'{head}emb.pt'),
                       map_location=self.device))
        self.Q.load_state_dict(
            torch.load(os.path.join(path, f'{head}Q.pt'),
                       map_location=self.device))
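
The Agent above stores transitions via self.memory.append(...) and draws uniform minibatches with self.memory.sample(batch_size). A minimal buffer consistent with that usage (an inferred sketch, not the original implementation) is:

import random
from collections import deque


class ReplayBuffer:
    def __init__(self, max_size=int(1e5)):
        self.buffer = deque(maxlen=max_size)

    def __len__(self):
        return len(self.buffer)

    def append(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        # Uniform random minibatch of stored transitions
        return random.sample(self.buffer, batch_size)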