def test_factory_rollout_buffer_from_state_wrong_type(): # Assign buffer = RolloutBuffer(batch_size=5, buffer_size=20) state = buffer.get_state() state.type = "WrongType" # Act with pytest.raises(ValueError): BufferFactory.from_state(state=state)
def test_rollout_buffer_length(): # Assign buffer_size = 10 buffer = RolloutBuffer(batch_size=5, buffer_size=buffer_size) # Act for (state, action, reward, next_state, done) in generate_sample_SARS(buffer_size+1): buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done) # Assert assert len(buffer) == buffer_size
def test_rollout_buffer_get_state_without_data(): # Assign buffer = RolloutBuffer(batch_size=5, buffer_size=20) # Act state = buffer.get_state() # Assert assert state.type == RolloutBuffer.type assert state.buffer_size == 20 assert state.batch_size == 5 assert state.data is None
def test_factory_rollout_buffer_from_state_without_data(): # Assign buffer = RolloutBuffer(batch_size=5, buffer_size=20) state = buffer.get_state() # Act new_buffer = BufferFactory.from_state(state=state) # Assert assert new_buffer == buffer assert new_buffer.buffer_size == state.buffer_size assert new_buffer.batch_size == state.batch_size assert len(new_buffer.data) == 0
def test_rollout_buffer_from_state_with_data(): # Assign buffer = RolloutBuffer(batch_size=5, buffer_size=20) buffer = populate_buffer(buffer, 30) state = buffer.get_state() # Act new_buffer = RolloutBuffer.from_state(state=state) # Assert assert new_buffer == buffer assert new_buffer.buffer_size == state.buffer_size assert new_buffer.batch_size == state.batch_size assert new_buffer.data == state.data assert len(buffer.data) == state.buffer_size
def test_rollout_buffer_sample_batch_equal_buffer(): # Assign buffer_size = batch_size = 20 buffer = RolloutBuffer(batch_size=batch_size, buffer_size=buffer_size) # Act for (state, action, reward, next_state, done) in generate_sample_SARS(buffer_size+1): buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done) # Assert num_samples = 0 for samples in buffer.sample(): num_samples += 1 for value in samples.values(): assert len(value) == batch_size assert num_samples == 1
def test_rollout_buffer_get_state_with_data(): # Assign buffer = RolloutBuffer(batch_size=5, buffer_size=20) sample_experience = Experience(state=[0, 1], action=[0], reward=0) for _ in range(25): buffer.add(**sample_experience.data) # Act state = buffer.get_state() # Assert assert state.type == RolloutBuffer.type assert state.buffer_size == 20 assert state.batch_size == 5 assert state.data is not None assert len(state.data) == 20 for data in state.data: assert data == sample_experience
def from_state(state: BufferState) -> BufferBase: if state.type == ReplayBuffer.type: return ReplayBuffer.from_state(state) elif state.type == PERBuffer.type: return PERBuffer.from_state(state) elif state.type == NStepBuffer.type: return NStepBuffer.from_state(state) elif state.type == RolloutBuffer.type: return RolloutBuffer.from_state(state) else: raise ValueError(f"Buffer state contains unsupported buffer type: '{state.type}'")
def test_rollout_buffer_travers_buffer_twice(): # Assign batch_size = 10 buffer_size = 30 buffer = RolloutBuffer(batch_size=batch_size, buffer_size=buffer_size) # Act reward = -1 for (state, action, _, next_state, done) in generate_sample_SARS(buffer_size): reward += 1 buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done) # Assert num_samples = 0 # First pass for idx, samples in enumerate(buffer.sample()): num_samples += 1 rewards = samples['reward'] assert rewards == list(range(idx*10, (idx+1)*10)) # Second pass for idx, samples in enumerate(buffer.sample()): num_samples += 1 rewards = samples['reward'] assert rewards == list(range(idx*10, (idx+1)*10)) assert num_samples == 6 # 2 * ceil(buffer_size / batch_size)
def test_rollout_buffer_size_not_multiple_of_minibatch(): # Assign batch_size = 10 buffer_size = 55 buffer = RolloutBuffer(batch_size=batch_size, buffer_size=buffer_size) # Act reward = -1 for (state, action, _, next_state, done) in generate_sample_SARS(buffer_size): reward += 1 buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done) # Assert num_samples = 0 for idx, samples in enumerate(buffer.sample()): num_samples += 1 rewards = samples['reward'] if idx != 5: assert len(rewards) == batch_size assert rewards == list(range(idx*10, (idx+1)*10)) else: assert len(rewards) == 5 assert rewards == [50, 51, 52, 53, 54] assert num_samples == 6 # ceil(buffer_size / batch_size)
def test_rollout_buffer_clear_buffer(): # Assign batch_size = 10 buffer_size = 30 buffer = RolloutBuffer(batch_size=batch_size, buffer_size=buffer_size) # Act reward = -1 for (state, action, _, next_state, done) in generate_sample_SARS(buffer_size): reward += 1 buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done) # Assert for idx, samples in enumerate(buffer.sample()): rewards = samples['reward'] assert rewards == list(range(idx*10, (idx+1)*10)) buffer.clear() assert len(buffer) == 0
def __init__(self, state_size: int, action_size: int, **kwargs): """ Parameters: state_size: Number of input dimensions. action_size: Number of output dimensions hidden_layers: (default: (100, 100) ) Tuple defining hidden dimensions in fully connected nets. is_discrete: (default: False) Whether return discrete action. kl_div: (default: False) Whether to use KL divergence in loss. using_gae: (default: True) Whether to use General Advantage Estimator. gae_lambda: (default: 0.96) Value of \lambda in GAE. actor_lr: (default: 0.0003) Learning rate for the actor (policy). critic_lr: (default: 0.001) Learning rate for the critic (value function). actor_betas: (default: (0.9, 0.999) Adam's betas for actor optimizer. critic_betas: (default: (0.9, 0.999) Adam's betas for critic optimizer. gamma: (default: 0.99) Discount value. ppo_ratio_clip: (default: 0.25) Policy ratio clipping value. num_epochs: (default: 1) Number of time to learn from samples. rollout_length: (default: 48) Number of actions to take before update. batch_size: (default: rollout_length) Number of samples used in learning. actor_number_updates: (default: 10) Number of times policy losses are propagated. critic_number_updates: (default: 10) Number of times value losses are propagated. entropy_weight: (default: 0.005) Weight of the entropy term in the loss. value_loss_weight: (default: 0.005) Weight of the entropy term in the loss. """ super().__init__(**kwargs) self.device = self._register_param( kwargs, "device", DEVICE) # Default device is CUDA if available self.state_size = state_size self.action_size = action_size self.hidden_layers = to_numbers_seq( self._register_param(kwargs, "hidden_layers", (100, 100))) self.iteration = 0 self.is_discrete = bool( self._register_param(kwargs, "is_discrete", False)) self.using_gae = bool(self._register_param(kwargs, "using_gae", True)) self.gae_lambda = float( self._register_param(kwargs, "gae_lambda", 0.96)) self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4)) self.actor_betas: Tuple[float, float] = to_numbers_seq( self._register_param(kwargs, 'actor_betas', (0.9, 0.999))) self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 1e-3)) self.critic_betas: Tuple[float, float] = to_numbers_seq( self._register_param(kwargs, 'critic_betas', (0.9, 0.999))) self.gamma = float(self._register_param(kwargs, "gamma", 0.99)) self.ppo_ratio_clip = float( self._register_param(kwargs, "ppo_ratio_clip", 0.25)) self.using_kl_div = bool( self._register_param(kwargs, "using_kl_div", False)) self.kl_beta = float(self._register_param(kwargs, 'kl_beta', 0.1)) self.target_kl = float(self._register_param(kwargs, "target_kl", 0.01)) self.kl_div = float('inf') self.num_workers = int(self._register_param(kwargs, "num_workers", 1)) self.num_epochs = int(self._register_param(kwargs, "num_epochs", 1)) self.rollout_length = int( self._register_param(kwargs, "rollout_length", 48)) # "Much less than the episode length" self.batch_size = int( self._register_param(kwargs, "batch_size", self.rollout_length)) self.actor_number_updates = int( self._register_param(kwargs, "actor_number_updates", 10)) self.critic_number_updates = int( self._register_param(kwargs, "critic_number_updates", 10)) self.entropy_weight = float( self._register_param(kwargs, "entropy_weight", 0.5)) self.value_loss_weight = float( self._register_param(kwargs, "value_loss_weight", 1.0)) self.local_memory_buffer = {} self.action_scale = float( self._register_param(kwargs, "action_scale", 1)) self.action_min = float(self._register_param(kwargs, "action_min", -1)) self.action_max = float(self._register_param(kwargs, "action_max", 1)) self.max_grad_norm_actor = float( self._register_param(kwargs, "max_grad_norm_actor", 100.0)) self.max_grad_norm_critic = float( self._register_param(kwargs, "max_grad_norm_critic", 100.0)) if kwargs.get("simple_policy", False): self.policy = MultivariateGaussianPolicySimple( self.action_size, **kwargs) else: self.policy = MultivariateGaussianPolicy(self.action_size, device=self.device) self.buffer = RolloutBuffer(batch_size=self.batch_size, buffer_size=self.rollout_length) self.actor = ActorBody(state_size, self.policy.param_dim * action_size, gate_out=torch.tanh, hidden_layers=self.hidden_layers, device=self.device) self.critic = ActorBody(state_size, 1, gate_out=None, hidden_layers=self.hidden_layers, device=self.device) self.actor_params = list(self.actor.parameters()) + list( self.policy.parameters()) self.critic_params = list(self.critic.parameters()) self.actor_opt = optim.Adam(self.actor_params, lr=self.actor_lr, betas=self.actor_betas) self.critic_opt = optim.Adam(self.critic_params, lr=self.critic_lr, betas=self.critic_betas) self._loss_actor = float('nan') self._loss_critic = float('nan') self._metrics: Dict[str, float] = {}
class PPOAgent(AgentBase): """ Proximal Policy Optimization (PPO) [1] is an online policy gradient method that could be considered as an implementation-wise simplified version of the Trust Region Policy Optimization (TRPO). [1] "Proximal Policy Optimization Algorithms" (2017) by J. Schulman, F. Wolski, P. Dhariwal, A. Radford, O. Klimov. https://arxiv.org/abs/1707.06347 """ name = "PPO" def __init__(self, state_size: int, action_size: int, **kwargs): """ Parameters: state_size: Number of input dimensions. action_size: Number of output dimensions hidden_layers: (default: (100, 100) ) Tuple defining hidden dimensions in fully connected nets. is_discrete: (default: False) Whether return discrete action. kl_div: (default: False) Whether to use KL divergence in loss. using_gae: (default: True) Whether to use General Advantage Estimator. gae_lambda: (default: 0.96) Value of \lambda in GAE. actor_lr: (default: 0.0003) Learning rate for the actor (policy). critic_lr: (default: 0.001) Learning rate for the critic (value function). actor_betas: (default: (0.9, 0.999) Adam's betas for actor optimizer. critic_betas: (default: (0.9, 0.999) Adam's betas for critic optimizer. gamma: (default: 0.99) Discount value. ppo_ratio_clip: (default: 0.25) Policy ratio clipping value. num_epochs: (default: 1) Number of time to learn from samples. rollout_length: (default: 48) Number of actions to take before update. batch_size: (default: rollout_length) Number of samples used in learning. actor_number_updates: (default: 10) Number of times policy losses are propagated. critic_number_updates: (default: 10) Number of times value losses are propagated. entropy_weight: (default: 0.005) Weight of the entropy term in the loss. value_loss_weight: (default: 0.005) Weight of the entropy term in the loss. """ super().__init__(**kwargs) self.device = self._register_param( kwargs, "device", DEVICE) # Default device is CUDA if available self.state_size = state_size self.action_size = action_size self.hidden_layers = to_numbers_seq( self._register_param(kwargs, "hidden_layers", (100, 100))) self.iteration = 0 self.is_discrete = bool( self._register_param(kwargs, "is_discrete", False)) self.using_gae = bool(self._register_param(kwargs, "using_gae", True)) self.gae_lambda = float( self._register_param(kwargs, "gae_lambda", 0.96)) self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4)) self.actor_betas: Tuple[float, float] = to_numbers_seq( self._register_param(kwargs, 'actor_betas', (0.9, 0.999))) self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 1e-3)) self.critic_betas: Tuple[float, float] = to_numbers_seq( self._register_param(kwargs, 'critic_betas', (0.9, 0.999))) self.gamma = float(self._register_param(kwargs, "gamma", 0.99)) self.ppo_ratio_clip = float( self._register_param(kwargs, "ppo_ratio_clip", 0.25)) self.using_kl_div = bool( self._register_param(kwargs, "using_kl_div", False)) self.kl_beta = float(self._register_param(kwargs, 'kl_beta', 0.1)) self.target_kl = float(self._register_param(kwargs, "target_kl", 0.01)) self.kl_div = float('inf') self.num_workers = int(self._register_param(kwargs, "num_workers", 1)) self.num_epochs = int(self._register_param(kwargs, "num_epochs", 1)) self.rollout_length = int( self._register_param(kwargs, "rollout_length", 48)) # "Much less than the episode length" self.batch_size = int( self._register_param(kwargs, "batch_size", self.rollout_length)) self.actor_number_updates = int( self._register_param(kwargs, "actor_number_updates", 10)) self.critic_number_updates = int( self._register_param(kwargs, "critic_number_updates", 10)) self.entropy_weight = float( self._register_param(kwargs, "entropy_weight", 0.5)) self.value_loss_weight = float( self._register_param(kwargs, "value_loss_weight", 1.0)) self.local_memory_buffer = {} self.action_scale = float( self._register_param(kwargs, "action_scale", 1)) self.action_min = float(self._register_param(kwargs, "action_min", -1)) self.action_max = float(self._register_param(kwargs, "action_max", 1)) self.max_grad_norm_actor = float( self._register_param(kwargs, "max_grad_norm_actor", 100.0)) self.max_grad_norm_critic = float( self._register_param(kwargs, "max_grad_norm_critic", 100.0)) if kwargs.get("simple_policy", False): self.policy = MultivariateGaussianPolicySimple( self.action_size, **kwargs) else: self.policy = MultivariateGaussianPolicy(self.action_size, device=self.device) self.buffer = RolloutBuffer(batch_size=self.batch_size, buffer_size=self.rollout_length) self.actor = ActorBody(state_size, self.policy.param_dim * action_size, gate_out=torch.tanh, hidden_layers=self.hidden_layers, device=self.device) self.critic = ActorBody(state_size, 1, gate_out=None, hidden_layers=self.hidden_layers, device=self.device) self.actor_params = list(self.actor.parameters()) + list( self.policy.parameters()) self.critic_params = list(self.critic.parameters()) self.actor_opt = optim.Adam(self.actor_params, lr=self.actor_lr, betas=self.actor_betas) self.critic_opt = optim.Adam(self.critic_params, lr=self.critic_lr, betas=self.critic_betas) self._loss_actor = float('nan') self._loss_critic = float('nan') self._metrics: Dict[str, float] = {} @property def loss(self) -> Dict[str, float]: return {'actor': self._loss_actor, 'critic': self._loss_critic} @loss.setter def loss(self, value): if isinstance(value, dict): self._loss_actor = value['actor'] self._loss_critic = value['critic'] else: self._loss_actor = value self._loss_critic = value def __eq__(self, o: object) -> bool: return super().__eq__(o) \ and self._config == o._config \ and self.buffer == o.buffer \ and self.get_network_state() == o.get_network_state() # TODO @dawid: Currently net isn't compared properly def __clear_memory(self): self.buffer.clear() @torch.no_grad() def act(self, state, epsilon: float = 0.): actions = [] logprobs = [] values = [] state = to_tensor(state).view(self.num_workers, self.state_size).float().to(self.device) for worker in range(self.num_workers): actor_est = self.actor.act(state[worker].unsqueeze(0)) assert not torch.any(torch.isnan(actor_est)) dist = self.policy(actor_est) action = dist.sample() value = self.critic.act( state[worker].unsqueeze(0)) # Shape: (1, 1) logprob = self.policy.log_prob(dist, action) # Shape: (1,) values.append(value) logprobs.append(logprob) if self.is_discrete: # *Technically* it's the max of Softmax but that's monotonic. action = int(torch.argmax(action)) else: action = torch.clamp(action * self.action_scale, self.action_min, self.action_max) action = action.cpu().numpy().flatten().tolist() actions.append(action) self.local_memory_buffer['value'] = torch.cat(values) self.local_memory_buffer['logprob'] = torch.stack(logprobs) assert len(actions) == self.num_workers return actions if self.num_workers > 1 else actions[0] def step(self, state, action, reward, next_state, done, **kwargs): self.iteration += 1 self.buffer.add( state=torch.tensor(state).reshape(self.num_workers, self.state_size).float(), action=torch.tensor(action).reshape(self.num_workers, self.action_size).float(), reward=torch.tensor(reward).reshape(self.num_workers, 1), done=torch.tensor(done).reshape(self.num_workers, 1), logprob=self.local_memory_buffer['logprob'].reshape( self.num_workers, 1), value=self.local_memory_buffer['value'].reshape( self.num_workers, 1), ) if self.iteration % self.rollout_length == 0: self.train() self.__clear_memory() def train(self): """ Main loop that initiates the training. """ experiences = self.buffer.all_samples() rewards = to_tensor(experiences['reward']).to(self.device) dones = to_tensor(experiences['done']).type(torch.int).to(self.device) states = to_tensor(experiences['state']).to(self.device) actions = to_tensor(experiences['action']).to(self.device) values = to_tensor(experiences['value']).to(self.device) logprobs = to_tensor(experiences['logprob']).to(self.device) assert rewards.shape == dones.shape == values.shape == logprobs.shape assert states.shape == ( self.rollout_length, self.num_workers, self.state_size), f"Wrong states shape: {states.shape}" assert actions.shape == ( self.rollout_length, self.num_workers, self.action_size), f"Wrong action shape: {actions.shape}" with torch.no_grad(): if self.using_gae: next_value = self.critic.act(states[-1]) advantages = compute_gae(rewards, dones, values, next_value, self.gamma, self.gae_lambda) advantages = normalize(advantages) returns = advantages + values # returns = normalize(advantages + values) assert advantages.shape == returns.shape == values.shape else: returns = revert_norm_returns(rewards, dones, self.gamma) returns = returns.float() advantages = normalize(returns - values) assert advantages.shape == returns.shape == values.shape for _ in range(self.num_epochs): idx = 0 self.kl_div = 0 while idx < self.rollout_length: _states = states[idx:idx + self.batch_size].view( -1, self.state_size).detach() _actions = actions[idx:idx + self.batch_size].view( -1, self.action_size).detach() _logprobs = logprobs[idx:idx + self.batch_size].view( -1, 1).detach() _returns = returns[idx:idx + self.batch_size].view(-1, 1).detach() _advantages = advantages[idx:idx + self.batch_size].view( -1, 1).detach() idx += self.batch_size self.learn( (_states, _actions, _logprobs, _returns, _advantages)) self.kl_div = abs( self.kl_div) / (self.actor_number_updates * self.rollout_length / self.batch_size) if self.kl_div > self.target_kl * 1.75: self.kl_beta = min(2 * self.kl_beta, 1e2) # Max 100 if self.kl_div < self.target_kl / 1.75: self.kl_beta = max(0.5 * self.kl_beta, 1e-6) # Min 0.000001 self._metrics['policy/kl_beta'] = self.kl_beta def compute_policy_loss(self, samples): states, actions, old_log_probs, _, advantages = samples actor_est = self.actor(states) dist = self.policy(actor_est) entropy = dist.entropy() new_log_probs = self.policy.log_prob(dist, actions).view(-1, 1) assert new_log_probs.shape == old_log_probs.shape r_theta = (new_log_probs - old_log_probs).exp() r_theta_clip = torch.clamp(r_theta, 1.0 - self.ppo_ratio_clip, 1.0 + self.ppo_ratio_clip) assert r_theta.shape == r_theta_clip.shape # KL = E[log(P/Q)] = sum_{P}( P * log(P/Q) ) -- \approx --> avg_{P}( log(P) - log(Q) ) approx_kl_div = (old_log_probs - new_log_probs).mean().item() if self.using_kl_div: # Ratio threshold for updates is 1.75 (although it should be configurable) policy_loss = -torch.mean( r_theta * advantages) + self.kl_beta * approx_kl_div else: joint_theta_adv = torch.stack( (r_theta * advantages, r_theta_clip * advantages)) assert joint_theta_adv.shape[0] == 2 policy_loss = -torch.amin(joint_theta_adv, dim=0).mean() entropy_loss = -self.entropy_weight * entropy.mean() loss = policy_loss + entropy_loss self._metrics['policy/kl_div'] = approx_kl_div self._metrics['policy/policy_ratio'] = float(r_theta.mean()) self._metrics['policy/policy_ratio_clip_mean'] = float( r_theta_clip.mean()) return loss, approx_kl_div def compute_value_loss(self, samples): states, _, _, returns, _ = samples values = self.critic(states) self._metrics['value/value_mean'] = values.mean() self._metrics['value/value_std'] = values.std() return F.mse_loss(values, returns) def learn(self, samples): self._loss_actor = 0. for _ in range(self.actor_number_updates): self.actor_opt.zero_grad() loss_actor, kl_div = self.compute_policy_loss(samples) self.kl_div += kl_div if kl_div > 1.5 * self.target_kl: # Early break # print(f"Iter: {i:02} Early break") break loss_actor.backward() nn.utils.clip_grad_norm_(self.actor_params, self.max_grad_norm_actor) self.actor_opt.step() self._loss_actor = loss_actor.item() for _ in range(self.critic_number_updates): self.critic_opt.zero_grad() loss_critic = self.compute_value_loss(samples) loss_critic.backward() nn.utils.clip_grad_norm_(self.critic_params, self.max_grad_norm_critic) self.critic_opt.step() self._loss_critic = float(loss_critic.item()) def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False): data_logger.log_value("loss/actor", self._loss_actor, step) data_logger.log_value("loss/critic", self._loss_critic, step) for metric_name, metric_value in self._metrics.items(): data_logger.log_value(metric_name, metric_value, step) policy_params = { str(i): v for i, v in enumerate( itertools.chain.from_iterable(self.policy.parameters())) } data_logger.log_values_dict("policy/param", policy_params, step) if full_log: for idx, layer in enumerate(self.actor.layers): if hasattr(layer, "weight"): data_logger.create_histogram(f"actor/layer_weights_{idx}", layer.weight, step) if hasattr(layer, "bias") and layer.bias is not None: data_logger.create_histogram(f"actor/layer_bias_{idx}", layer.bias, step) for idx, layer in enumerate(self.critic.layers): if hasattr(layer, "weight"): data_logger.create_histogram(f"critic/layer_weights_{idx}", layer.weight, step) if hasattr(layer, "bias") and layer.bias is not None: data_logger.create_histogram(f"critic/layer_bias_{idx}", layer.bias, step) def get_state(self) -> AgentState: return AgentState(model=self.name, state_space=self.state_size, action_space=self.action_size, config=self._config, buffer=copy.deepcopy(self.buffer.get_state()), network=copy.deepcopy(self.get_network_state())) def get_network_state(self) -> NetworkState: return NetworkState(net=dict( policy=self.policy.state_dict(), actor=self.actor.state_dict(), critic=self.critic.state_dict(), )) def set_buffer(self, buffer_state: BufferState) -> None: self.buffer = BufferFactory.from_state(buffer_state) def set_network(self, network_state: NetworkState) -> None: self.policy.load_state_dict(network_state.net['policy']) self.actor.load_state_dict(network_state.net['actor']) self.critic.load_state_dict(network_state.net['critic']) @staticmethod def from_state(state: AgentState) -> AgentBase: agent = PPOAgent(state.state_space, state.action_space, **state.config) if state.network is not None: agent.set_network(state.network) if state.buffer is not None: agent.set_buffer(state.buffer) return agent def save_state(self, path: str): agent_state = self.get_state() torch.save(agent_state, path) def load_state(self, path: str): agent_state = torch.load(path) self._config = agent_state.get('config', {}) self.__dict__.update(**self._config) self.policy.load_state_dict(agent_state['policy']) self.actor.load_state_dict(agent_state['actor']) self.critic.load_state_dict(agent_state['critic'])