Code Example #1
class AgentDDPG:
    def __init__(self, state_size, action_size, seed):
        """

        :state_size: size of the state vector
        :action_size: size of the action vector
        """

        self.state_size = state_size
        self.action_size = action_size
        self.t_step = 0
        self.score = 0.0
        self.best = 0.0
        self.seed = seed
        self.total_reward = 0.0
        self.count = 0
        self.learning_rate_actor = 0.0001
        self.learning_rate_critic = 0.001
        self.batch_size = 128
        self.update_every = 1

        # Instances of the policy function (actor) and the value function (critic)
        # used by the DDPG actor-critic setup

        # Actor local and target network definitions
        self.actor_local = Actor(self.state_size, self.action_size,
                                 self.seed).to(device)

        self.actor_target = Actor(self.state_size, self.action_size,
                                  self.seed).to(device)

        # Critic local and target
        self.critic_local = Critic(self.state_size, self.action_size,
                                   self.seed).to(device)

        self.critic_target = Critic(self.state_size, self.action_size,
                                    self.seed).to(device)
        # Actor Optimizer
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=self.learning_rate_actor)

        # Critic Optimizer
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=self.learning_rate_critic)

        # Make sure local and target start with the same weights
        self.actor_target.load_state_dict(self.actor_local.state_dict())
        self.critic_target.load_state_dict(self.critic_local.state_dict())

        # Initialize the Ornstein-Uhlenbeck noise process for exploration
        self.exploration_mu = 0
        self.exploration_theta = 0.15
        self.exploration_sigma = 0.2
        self.noise = OUNoise(self.action_size, self.exploration_mu,
                             self.exploration_theta, self.exploration_sigma)

        # Initialize the Replay Memory
        self.buffer_size = 1000000
        self.memory = ReplayBuffer(self.buffer_size, self.batch_size)

        # Parameters for the Algorithm
        self.gamma = 0.99  # Discount factor
        self.tau = 0.001  # Soft-update factor for the target network parameters

    # The agent interacts with the environment through step()
    def step(self, state, action, reward, next_state, done):
        # Add this time step's reward to the running total
        self.total_reward += reward
        # Count the number of rewards received in the episode
        self.count += 1
        # Store the experience tuple in the replay buffer
        self.memory.add(state, action, reward, next_state, done)

        # Learn every update_every time steps.
        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:

            # Check to see if you have enough to produce a batch
            # and learn from it

            if len(self.memory) > self.batch_size:
                experiences = self.memory.sample()
                # Train the networks using the experiences
                self.learn(experiences)

        # Roll over last state action (not needed)
        # self.last_state = next_state

    # Actor determines what to do based on the policy
    def act(self, state):
        # Given a state return the action recommended by the policy
        # Reshape the state to fit the torch tensor input
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)

        # Pass the state to the local actor model to get the action
        # recommended by the current policy.
        # Put the model in evaluation mode (inference only, no training)
        self.actor_local.eval()
        # Disable gradient tracking for this forward pass
        with torch.no_grad():
            actions = self.actor_local(state)
        # set the model back to training mode
        self.actor_local.train()

        # Because we are exploring we add some noise to the
        # action vector
        return list(actions.cpu().numpy().reshape(self.action_size) +
                    self.noise.sample())

    # Learning logic, called from step() once enough experience
    # has been collected
    def learn(self, experiences):
        """
        Learning means that the networks parameters needs to be updated
        Using the experineces batch.
        Network learns from experiences not form interaction with the
        environment
        """

        # Reshape the experience tuples into separate arrays of states, actions,
        # rewards, next_states and dones.
        # Every member of the tuple becomes a column (vector).
        states = np.vstack([e.state for e in experiences if e is not None])
        actions = np.array([e.action for e in experiences
                            if e is not None]).astype(np.float32).reshape(
                                -1, self.action_size)
        rewards = np.array([e.reward for e in experiences if e is not None
                            ]).astype(np.float32).reshape(-1, 1)
        dones = np.array([e.done for e in experiences
                          if e is not None]).astype(np.uint8).reshape(-1, 1)
        next_states = np.vstack(
            [e.next_state for e in experiences if e is not None])

        # Now convert the numpy arrays for states, actions and next_states to torch tensors;
        # rewards and dones do not need to be tensors.
        states = torch.from_numpy(states).float().to(device)
        actions = torch.from_numpy(actions).float().to(device)
        next_states = torch.from_numpy(next_states).float().to(device)

        # First we pass a batch of next states to the target actor so it tells us which
        # actions to execute. We use the actor target network instead of the actor local
        # network to keep the bootstrap targets stable.

        # Set the target network to evaluation mode because this is not part of the training;
        # this model's weights are altered by a soft update, not by an optimizer.
        self.actor_target.eval()
        with torch.no_grad():
            next_state_actions = self.actor_target(next_states).detach()
        self.actor_target.train()

        # The critic evaluates the actions taken by the actor in the next state and generates
        # the Q(s, a) value of the next state for those actions. These (action, next_state)
        # pairs come from the ReplayBuffer, not from interacting with the environment.
        # Remember the critic (Q-value function) takes states and actions as input.
        # We calculate the Q-targets of the next state and use them to compute the current
        # state Q-value via the Bellman equation.

        # Set the target network to evaluation mode because this is not part of the training;
        # this model's weights are altered by a soft update, not by an optimizer.
        self.critic_target.eval()
        with torch.no_grad():
            q_targets_next_state_action_values = self.critic_target(
                next_states, next_state_actions).detach()
        self.critic_target.train()

        # With the next-state Q-values (one for each transition sampled from the replay
        # buffer) we calculate the CURRENT state target Q(s, a) using the one-step TD
        # equation and the Q-targets-next values from the critic_target network.
        # Terminal states get only the reward as target; non-terminal states add the
        # discounted next-state value. These targets are used to train the critic_local
        # model in a supervised-learning fashion.
        q_targets = torch.from_numpy(
            rewards + self.gamma *
            q_targets_next_state_action_values.cpu().numpy() *
            (1 - dones)).float().to(device)

        # --- Optimize the local Critic Model ----#

        # Here we start the supervised training of the critic_local network:
        # we pass a batch of (state, action) samples and it produces the expected
        # Q-value for each pair.
        q_expected = self.critic_local(states, actions)

        # Clear grad buffer values in preparation.
        self.critic_optimizer.zero_grad()

        # Loss function for the critic_local model: Huber loss (smooth L1) between
        # the expected Q-values and the target Q-values.
        critic_loss = F.smooth_l1_loss(q_expected, q_targets)
        critic_loss.backward()

        # gradient clipping
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)

        # Optimize the critic_local model using the critic optimizer
        # defined in the __init__ of this class
        self.critic_optimizer.step()

        # --- Optimize the local Actor Model ---#

        # Get the actor actions using the experience buffer states
        actor_actions = self.actor_local(states)

        # Use as the actor loss the negative sum of the Q-values produced by the freshly
        # optimized critic_local model for the actions that actor_local proposes on the
        # sampled states.
        loss_actor = -1 * torch.sum(
            self.critic_local.forward(states, actor_actions))

        # Set the model gradients to zero in preparation
        self.actor_optimizer.zero_grad()

        # Back propagate
        loss_actor.backward()

        # Optimize the actor_local model using the actor optimizer
        # defined in the __init__ of this class
        self.actor_optimizer.step()

        # Soft-update target models
        self.soft_update(self.critic_local, self.critic_target)
        self.soft_update(self.actor_local, self.actor_target)

    def soft_update(self, local_model, target_model):
        """Soft-update: theta_target = tau * theta_local + (1 - tau) * theta_target."""

        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(self.tau * local_param.data +
                                    (1.0 - self.tau) * target_param.data)

    def get_episode_score(self):
        """
        Calculate the episode scores
        :return: None
        """
        # Update score and best score
        self.score = self.total_reward / float(
            self.count) if self.count else 0.0
        if self.score > self.best:
            self.best = self.score

    def save_model_weights(self):
        torch.save(self.actor_local.state_dict(), './checkpoints.pkl')
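
Code Example #1 instantiates several helpers that are not listed on this page (Actor, Critic, OUNoise, ReplayBuffer). For reference, here is a minimal sketch of an Ornstein-Uhlenbeck noise process matching the call OUNoise(action_size, mu, theta, sigma) and the sample() method used above; it is an assumed implementation, not the source project's own class.

import copy
import numpy as np


class OUNoise:
    """Minimal Ornstein-Uhlenbeck noise sketch (assumed interface, not the original)."""

    def __init__(self, size, mu, theta, sigma):
        self.mu = mu * np.ones(size)   # long-running mean of the process
        self.theta = theta             # speed of mean reversion
        self.sigma = sigma             # scale of the random component
        self.reset()

    def reset(self):
        """Reset the internal state to the mean."""
        self.state = copy.copy(self.mu)

    def sample(self):
        """Advance the process one step and return the new state as a noise sample."""
        dx = self.theta * (self.mu - self.state) + \
            self.sigma * np.random.standard_normal(len(self.state))
        self.state = self.state + dx
        return self.state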
Code Example #2
class ActorCritic(Model):
    def __init__(self,
                 observation_space_size,
                 action_space_size,
                 name=None,
                 env_name=None,
                 model_config=None,
                 play_mode=False):

        if name is None:
            name = "Unnamed-ActorCritic"
        super(ActorCritic,
              self).__init__(observation_space_size, action_space_size, name,
                             env_name, model_config, play_mode)

    def build_model(self):

        self.policy_net = Actor(self.observation_space_size,
                                self.action_space_size)
        self.critic_net = Critic(self.observation_space_size)

        if self.model_config is None:
            self.gamma = 0.99

            self.actor_optimizer = optim.Adam(self.policy_net.parameters())
            self.actor_loss = nn.MSELoss()

            self.critic_optimizer = optim.Adam(self.critic_net.parameters())
            self.critic_loss = nn.MSELoss()

            self.get_epsilon = self.get_epsilon_default
        else:
            pass

    def save_checkpoint(self, n=0, filepath=None):
        """
        n - number of epoch / episode or whatever is used for enumeration
        """

        # TO DO: ADD OTHER RELEVANT PARAMETERS
        checkpoint = {
            'policy': self.policy_net.state_dict(),
            'critic': self.critic_net.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict()
        }
        super(ActorCritic, self).save_checkpoint(n, filepath, checkpoint)

    def load_checkpoint(self, filepath):
        # TO DO: ADD OTHER RELEVANT parameters
        checkpoint = torch.load(filepath)
        self.policy_net.load_state_dict(checkpoint['policy'])
        self.critic_net.load_state_dict(checkpoint['critic'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer'])

    def prepare_sample(self, sample):
        # The raw sample is a list of (state, action, reward, next_state, done) tuples;
        # stack each column into a contiguous numeric array before building tensors.
        sample = np.array(sample, dtype=object)
        states = torch.tensor(np.stack(sample[:, 0]), dtype=torch.float32)
        actions = torch.tensor(np.stack(sample[:, 1]), dtype=torch.float32)
        rewards = torch.tensor(sample[:, 2].astype(np.float32))
        next_states = torch.tensor(np.stack(sample[:, 3]), dtype=torch.float32)
        dones = torch.tensor(sample[:, 4].astype(np.int32))

        return states, actions, rewards, next_states, dones

    def critic_update(self, V, V_target):
        self.critic_optimizer.zero_grad()
        critic_loss = self.critic_loss(V, V_target)
        critic_loss.backward()
        self.critic_optimizer.step()

        return critic_loss.item()

    def actor_update(self, advantages, actions, mus):
        self.actor_optimizer.zero_grad()
        actor_loss = self.actor_loss(actions, mus)
        # Reduce to a scalar so backward() is well defined for batched advantages too.
        gradient_term = (advantages * actor_loss).mean()
        gradient_term.backward()
        self.actor_optimizer.step()

        return actor_loss.item()

    def update(self, sample, prepare_state=None):
        actor_running_loss = []
        critic_running_loss = []

        for state, action, reward, next_state, done in sample:
            if prepare_state is not None:
                state = prepare_state(state)
                next_state = prepare_state(next_state)

            state = torch.tensor(state, dtype=torch.float32)
            next_state = torch.tensor(next_state, dtype=torch.float32)
            action = torch.tensor(action, dtype=torch.float32)

            # Update Critic
            V = self.critic_net.forward(state)
            V_target = torch.tensor([reward], dtype=torch.float32)
            if not done:
                V_target += self.gamma * self.critic_net.forward(next_state)

            critic_loss = self.critic_update(V, V_target)
            critic_running_loss.append(critic_loss)

            # Update Actor
            advantage = (V_target - V).detach()
            mu = self.policy_net(state)

            actor_loss = self.actor_update(advantage, action, mu)
            actor_running_loss.append(actor_loss)

        return actor_running_loss, critic_running_loss

    def batch_update(self, sample, prepare_state=None):
        actor_running_loss = []
        critic_running_loss = []

        states, actions, rewards, next_states, dones = self.prepare_sample(
            sample)

        # Update Critic
        # Flatten the critic outputs so they broadcast against the 1-D reward and done tensors.
        V = self.critic_net.forward(states).view(-1)
        V_target = rewards + self.gamma * self.critic_net.forward(
            next_states).view(-1) * (1 - dones)

        critic_loss = self.critic_update(V, V_target)
        critic_running_loss.append(critic_loss)

        # Update Actor
        advantage = (V_target - V).detach()
        mu = self.policy_net(states)

        actor_loss = self.actor_update(advantage, actions, mu)
        actor_running_loss.append(actor_loss)

        return actor_running_loss, critic_running_loss
Code Example #3
class Agent(object):
    def __init__(self, n_states, n_actions, lr_actor, lr_critic, tau, gamma,
                 mem_size, actor_l1_size, actor_l2_size, critic_l1_size,
                 critic_l2_size, batch_size):

        self.gamma = gamma
        self.tau = tau
        self.memory = ReplayBuffer(mem_size, n_states, n_actions)
        self.batch_size = batch_size

        self.actor = Actor(lr_actor, n_states, n_actions, actor_l1_size,
                           actor_l2_size)
        self.critic = Critic(lr_critic, n_states, n_actions, critic_l1_size,
                             critic_l2_size)

        self.target_actor = Actor(lr_actor, n_states, n_actions, actor_l1_size,
                                  actor_l2_size)
        self.target_critic = Critic(lr_critic, n_states, n_actions,
                                    critic_l1_size, critic_l2_size)

        self.noise = OUActionNoise(mu=np.zeros(n_actions), sigma=0.005)

        self.update_network_parameters(tau=1)

    def choose_action(self, observation):
        self.actor.eval()
        observation = torch.tensor(observation,
                                   dtype=torch.float).to(self.actor.device)
        mu = self.actor.forward(observation).to(self.actor.device)

        # add noise to action - for exploration
        mu_prime = mu + torch.tensor(self.noise(), dtype=torch.float).to(
            self.actor.device)
        self.actor.train()

        return mu_prime.cpu().detach().numpy()

    def choose_action_no_train(self, observation):
        self.actor.eval()
        observation = torch.tensor(observation,
                                   dtype=torch.float).to(self.actor.device)
        mu = self.actor.forward(observation).to(self.actor.device)

        return mu.cpu().detach().numpy()

    def remember(self, state, action, reward, new_state, done):
        self.memory.push(state, action, reward, new_state, done)

    def learn(self):
        if self.memory.idx_last < self.batch_size:
            # not enough data in replay buffer
            return

        # select random events
        state, action, reward, new_state, done = self.memory.sample_buffer(
            self.batch_size)

        reward = torch.tensor(reward, dtype=torch.float).to(self.critic.device)
        done = torch.tensor(done).to(self.critic.device)
        new_state = torch.tensor(new_state,
                                 dtype=torch.float).to(self.critic.device)
        action = torch.tensor(action, dtype=torch.float).to(self.critic.device)
        state = torch.tensor(state, dtype=torch.float).to(self.critic.device)

        self.target_actor.eval()
        self.target_critic.eval()
        self.critic.eval()
        target_actions = self.target_actor.forward(new_state)
        critic_value_ = self.target_critic.forward(new_state, target_actions)
        critic_value = self.critic.forward(state, action)

        # NOTE: this assumes the replay buffer returns `done` as a continuation mask
        # (1 for non-terminal, 0 for terminal transitions).
        target = []
        for j in range(self.batch_size):
            target.append(reward[j] + self.gamma * critic_value_[j] * done[j])
        target = torch.stack(target).detach().to(self.critic.device)
        target = target.view(self.batch_size, 1)

        self.critic.train()
        self.critic.optimizer.zero_grad()
        critic_loss = F.mse_loss(target, critic_value)
        critic_loss.backward()
        self.critic.optimizer.step()

        self.critic.eval()
        self.actor.optimizer.zero_grad()
        mu = self.actor.forward(state)
        self.actor.train()
        actor_loss = -self.critic.forward(state, mu)
        actor_loss = torch.mean(actor_loss)
        actor_loss.backward()
        self.actor.optimizer.step()

        self.update_network_parameters()

    def update_network_parameters(self, tau=None):
        if tau is None:
            tau = self.tau

        actor_params = self.actor.named_parameters()
        critic_params = self.critic.named_parameters()
        target_actor_params = self.target_actor.named_parameters()
        target_critic_params = self.target_critic.named_parameters()

        critic_state_dict = dict(critic_params)
        actor_state_dict = dict(actor_params)
        target_critic_dict = dict(target_critic_params)
        target_actor_dict = dict(target_actor_params)

        for name in critic_state_dict:
            critic_state_dict[name] = tau*critic_state_dict[name].clone() + \
                                      (1-tau)*target_critic_dict[name].clone()

        self.target_critic.load_state_dict(critic_state_dict)

        for name in actor_state_dict:
            actor_state_dict[name] = tau*actor_state_dict[name].clone() + \
                                      (1-tau)*target_actor_dict[name].clone()
        self.target_actor.load_state_dict(actor_state_dict)

    def save_models(self):
        timestamp = time.strftime("%Y%m%d-%H%M%S")

        self.actor.save("actor_" + timestamp)
        self.target_actor.save("target_actor_" + timestamp)
        self.critic.save("critic_" + timestamp)
        self.target_critic.save("target_critic_" + timestamp)

    def load_models(self, fn_actor, fn_target_actor, fn_critic,
                    fn_target_critic):
        self.actor.load_checkpoint(fn_actor)
        self.target_actor.load_checkpoint(fn_target_actor)
        self.critic.load_checkpoint(fn_critic)
        self.target_critic.load_checkpoint(fn_target_critic)
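
A minimal training-loop sketch for Code Example #3, assuming a classic Gym-style continuous-control environment; the environment id, learning rates and layer sizes below are placeholders, not values from the source. As noted in learn(), the replay buffer is assumed to hand back done as a continuation mask.

import gym

env = gym.make('Pendulum-v1')
agent = Agent(n_states=env.observation_space.shape[0],
              n_actions=env.action_space.shape[0],
              lr_actor=1e-4, lr_critic=1e-3, tau=0.001, gamma=0.99,
              mem_size=1_000_000, actor_l1_size=400, actor_l2_size=300,
              critic_l1_size=400, critic_l2_size=300, batch_size=64)

for episode in range(100):
    state = env.reset()
    done = False
    while not done:
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)
        agent.remember(state, action, reward, next_state, done)
        agent.learn()          # no-op until the buffer holds a full batch
        state = next_state

agent.save_models()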
Code Example #4
class DDPGAgent:
    def __init__(self,
                 state_space_dim,
                 action_space_dim,
                 min_action_val,
                 max_action_val,
                 hidden_layer_size=512,
                 gamma=0.99,
                 tau=0.0001,
                 path_to_load=None):
        self.gamma = gamma
        self.tau = tau
        self.min_action_val = min_action_val
        self.max_action_val = max_action_val
        self.buffer = Buffer(state_space_dim, action_space_dim)
        self.noise_generator = GaussianNoise(0., 0.2, action_space_dim)

        self.actor = Actor(state_space_dim, action_space_dim, max_action_val,
                           hidden_layer_size)
        self.critic = Critic(state_space_dim, action_space_dim,
                             hidden_layer_size)

        if path_to_load is not None:
            if os.path.exists(path_to_load + "_actor.h5") and \
                    os.path.exists(path_to_load + "_critic.h5"):
                self.load(path_to_load)

        self.target_actor = Actor(state_space_dim, action_space_dim,
                                  max_action_val, hidden_layer_size)
        self.target_critic = Critic(state_space_dim, action_space_dim,
                                    hidden_layer_size)

        self.target_actor.model.set_weights(self.actor.model.get_weights())
        self.target_critic.model.set_weights(self.critic.model.get_weights())

        critic_lr = 0.002
        actor_lr = 0.001

        self.critic_optimizer = tf.keras.optimizers.Adam(critic_lr)
        self.actor_optimizer = tf.keras.optimizers.Adam(actor_lr)

    @tf.function
    def _apply_gradients(self, states, actions, next_states, rewards):
        with tf.GradientTape() as tape:
            target_actions = self.target_actor.forward(next_states)
            y = tf.cast(rewards,
                        tf.float32) + self.gamma * self.target_critic.forward(
                            [next_states, target_actions])
            critic_value = self.critic.forward([states, actions])
            critic_loss = tf.math.reduce_mean(tf.math.square(y - critic_value))

        critic_grad = tape.gradient(critic_loss,
                                    self.critic.model.trainable_variables)
        self.critic_optimizer.apply_gradients(
            zip(critic_grad, self.critic.model.trainable_variables))

        with tf.GradientTape() as tape:
            actions = self.actor.forward(states)
            critic_value = self.critic.forward([states, actions])
            actor_loss = -tf.math.reduce_mean(critic_value)

        actor_grad = tape.gradient(actor_loss,
                                   self.actor.model.trainable_variables)
        self.actor_optimizer.apply_gradients(
            zip(actor_grad, self.actor.model.trainable_variables))

    def learn(self):
        states, actions, next_states, rewards = self.buffer.sample()
        self._apply_gradients(states, actions, next_states, rewards)

    def remember_step(self, info):
        self.buffer.remember(info)

    def update_targets(self):
        new_weights = []
        target_variables = self.target_critic.model.weights
        for i, variable in enumerate(self.critic.model.weights):
            new_weights.append(variable * self.tau + target_variables[i] *
                               (1 - self.tau))

        self.target_critic.model.set_weights(new_weights)

        new_weights = []
        target_variables = self.target_actor.model.weights
        for i, variable in enumerate(self.actor.model.weights):
            new_weights.append(variable * self.tau + target_variables[i] *
                               (1 - self.tau))

        self.target_actor.model.set_weights(new_weights)

    def get_best_action(self, state):
        tf_state = tf.expand_dims(tf.convert_to_tensor(state), 0)
        return tf.squeeze(self.actor.forward(tf_state)).numpy()

    def get_action(self, state):
        actions = self.get_best_action(
            state) + self.noise_generator.get_noise()
        return np.clip(actions, self.min_action_val, self.max_action_val)

    def save(self, path):
        print(f"Model has been saved as '{path}'")
        self.actor.save(path)
        self.critic.save(path)

    def load(self, path):
        print(f"Model has been loaded from '{path}'")
        self.actor.load(path)
        self.critic.load(path)
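
Code Example #4 relies on Buffer and GaussianNoise helpers that are not shown here. A plausible minimal implementation of the noise generator, matching the call GaussianNoise(0., 0.2, action_space_dim) and the get_noise() method, could look like this (an assumption, not the project's actual class):

import numpy as np


class GaussianNoise:
    """Minimal sketch of the exploration-noise helper assumed by Code Example #4."""

    def __init__(self, mean, std, action_dim):
        self.mean = mean
        self.std = std
        self.action_dim = action_dim

    def get_noise(self):
        # Independent Gaussian noise for each action dimension.
        return np.random.normal(self.mean, self.std, self.action_dim)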
Code Example #5
File: td3.py  Project: TomoyaAkiyama/TD3
class TD3:
    def __init__(self,
                 device,
                 state_dim,
                 action_dim,
                 action_max,
                 gamma=0.99,
                 tau=0.005,
                 lr=3e-4,
                 policy_noise=0.2,
                 noise_clip=0.5,
                 exploration_noise=0.1,
                 policy_freq=2):

        self.actor = Actor(state_dim, 256, action_dim, action_max).to(device)
        self.target_actor = copy.deepcopy(self.actor)
        self.actor_optimizer = optim.Adam(params=self.actor.parameters(),
                                          lr=lr)
        self.critic = Critic(state_dim, 256, action_dim).to(device)
        self.target_critic = copy.deepcopy(self.critic)
        self.critic_optimizer = optim.Adam(params=self.critic.parameters(),
                                           lr=lr)

        self.device = device
        self.gamma = gamma
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq

        self.rollout_actor = TD3RolloutActor(state_dim, action_dim, action_max,
                                             exploration_noise)
        self.sync_rollout_actor()

        self.iteration_num = 0

    def train(self, replay_buffer, batch_size=256):
        self.iteration_num += 1

        st, nx_st, ac, rw, mask = replay_buffer.sample(batch_size)
        with torch.no_grad():
            noise = (torch.randn_like(ac) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip)
            nx_ac = self.target_actor.forward(nx_st, noise)

            target_q1, target_q2 = self.target_critic.forward(nx_st, nx_ac)
            min_q = torch.min(target_q1, target_q2)
            target_q = rw + mask * self.gamma * min_q

        q1, q2 = self.critic.forward(st, ac)
        critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q)
        self.critic.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        if self.iteration_num % self.policy_freq == 0:
            actor_loss = -self.critic.q1(st, self.actor.forward(st)).mean()
            self.actor.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            for param, target_param in zip(self.critic.parameters(),
                                           self.target_critic.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(),
                                           self.target_actor.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

        self.sync_rollout_actor()

    def sync_rollout_actor(self):
        for param, target_param in zip(self.actor.parameters(),
                                       self.rollout_actor.parameters()):
            target_param.data.copy_(param.data.cpu())

    def save(self, path):
        torch.save(self.critic.state_dict(), os.path.join(path, 'critic.pth'))
        torch.save(self.target_critic.state_dict(),
                   os.path.join(path, 'target_critic.pth'))
        torch.save(self.critic_optimizer.state_dict(),
                   os.path.join(path, 'critic_optimizer.pth'))

        torch.save(self.actor.state_dict(), os.path.join(path, 'actor.pth'))
        torch.save(self.target_actor.state_dict(),
                   os.path.join(path, 'target_actor.pth'))
        torch.save(self.actor_optimizer.state_dict(),
                   os.path.join(path, 'actor_optimizer.pth'))

    def load(self, path):
        self.critic.load_state_dict(
            torch.load(os.path.join(path, 'critic.pth')))
        self.target_critic.load_state_dict(
            torch.load(os.path.join(path, 'target_critic.pth')))
        self.critic_optimizer.load_state_dict(
            torch.load(os.path.join(path, 'critic_optimizer.pth')))

        self.actor.load_state_dict(torch.load(os.path.join(path, 'actor.pth')))
        self.target_actor.load_state_dict(
            torch.load(os.path.join(path, 'target_actor.pth')))
        self.actor_optimizer.load_state_dict(
            torch.load(os.path.join(path, 'actor_optimizer.pth')))
        self.sync_rollout_actor()
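
Code Example #5 expects replay_buffer.sample(batch_size) to return (state, next_state, action, reward, mask) tensors, where mask plays the role of 1 - done so the bootstrap term vanishes at terminal states. The buffer itself is not listed; the sketch below is one plausible implementation of that interface (class name, add() signature and storage layout are assumptions, not the project's code).

import numpy as np
import torch


class ReplayBuffer:
    def __init__(self, state_dim, action_dim, max_size, device):
        self.device = device
        self.max_size = max_size
        self.ptr = 0
        self.size = 0
        self.states = np.zeros((max_size, state_dim), dtype=np.float32)
        self.next_states = np.zeros((max_size, state_dim), dtype=np.float32)
        self.actions = np.zeros((max_size, action_dim), dtype=np.float32)
        self.rewards = np.zeros((max_size, 1), dtype=np.float32)
        self.masks = np.zeros((max_size, 1), dtype=np.float32)

    def add(self, state, next_state, action, reward, done):
        # Overwrite the oldest entry once the ring buffer is full.
        self.states[self.ptr] = state
        self.next_states[self.ptr] = next_state
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.masks[self.ptr] = 1.0 - float(done)
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        idx = np.random.randint(0, self.size, size=batch_size)
        batch = (self.states[idx], self.next_states[idx], self.actions[idx],
                 self.rewards[idx], self.masks[idx])
        return tuple(torch.as_tensor(x, device=self.device) for x in batch)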