class SACAgent:
    """Soft Actor-Critic agent with a separate state-value network (SAC v1).

    Holds a value network V and its Polyak-averaged target, two soft Q
    networks, and a Gaussian policy whose samples are squashed with tanh and
    rescaled to the environment's action bounds.

    Args:
        env: Environment exposing ``action_space`` (``low``, ``high``,
            ``shape``) and ``observation_space`` (``shape``).
        gamma: Discount factor.
        tau: Soft-update rate for the target value network.
        v_lr: Learning rate of the value network.
        q_lr: Learning rate of both Q networks.
        policy_lr: Learning rate of the policy network.
        buffer_maxlen: Capacity passed to ``BasicBuffer``.
    """

    def __init__(self, env, gamma, tau, v_lr, q_lr, policy_lr, buffer_maxlen):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.env = env
        self.action_range = [env.action_space.low, env.action_space.high]
        self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

        # hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.update_step = 0
        # The policy and the target value network are updated once every
        # `delay_step` calls to `update`.
        self.delay_step = 2

        # initialize networks
        self.value_net = ValueNetwork(self.obs_dim, 1).to(self.device)
        self.target_value_net = ValueNetwork(self.obs_dim, 1).to(self.device)
        self.q_net1 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.policy_net = PolicyNetwork(self.obs_dim, self.action_dim).to(self.device)

        # Start the target value network as an exact copy of the online one.
        for target_param, param in zip(self.target_value_net.parameters(),
                                       self.value_net.parameters()):
            target_param.data.copy_(param)

        # initialize optimizers
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=v_lr)
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)

        self.replay_buffer = BasicBuffer(buffer_maxlen)

    def get_action(self, state):
        """Sample a stochastic action for a single state, rescaled to env bounds."""
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            mean, log_std = self.policy_net.forward(state)
            std = log_std.exp()
            z = Normal(mean, std).sample()
            action = torch.tanh(z)
        action = action.cpu().detach().squeeze(0).numpy()
        return self.rescale_action(action)

    def rescale_action(self, action):
        """Map an action from [-1, 1] onto [action_range[0], action_range[1]]."""
        return action * (self.action_range[1] - self.action_range[0]) / 2.0 +\
            (self.action_range[1] + self.action_range[0]) / 2.0

    def update(self, batch_size):
        """Run one SAC v1 gradient step on a batch sampled from the buffer.

        The value and Q networks are updated every call. The policy and the
        target value network are updated every ``delay_step`` calls.
        """
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            batch_size)
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        # Force (batch, 1) so targets combine element-wise with the (batch, 1)
        # network outputs instead of broadcasting to (batch, batch).
        rewards = rewards.view(rewards.size(0), -1)
        dones = dones.view(dones.size(0), -1)

        with torch.no_grad():
            # Value target: E_{a~pi(.|s)}[min(Q1, Q2)(s, a) - log pi(a|s)].
            # It is evaluated at the CURRENT states, because it is regressed
            # against V(states) below.
            sampled_actions, sampled_log_pi = self.policy_net.sample(states)
            sampled_min_q = torch.min(self.q_net1(states, sampled_actions),
                                      self.q_net2(states, sampled_actions))
            v_target = sampled_min_q - sampled_log_pi

            # Q target bootstraps from the target value network at s'.
            next_v = self.target_value_net(next_states)
            expected_q = rewards + (1 - dones) * self.gamma * next_v

        # value loss
        curr_v = self.value_net.forward(states)
        v_loss = F.mse_loss(curr_v, v_target)

        # q loss
        curr_q1 = self.q_net1.forward(states, actions)
        curr_q2 = self.q_net2.forward(states, actions)
        q1_loss = F.mse_loss(curr_q1, expected_q)
        q2_loss = F.mse_loss(curr_q2, expected_q)

        # update value network and q networks
        self.value_optimizer.zero_grad()
        v_loss.backward()
        self.value_optimizer.step()

        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()

        # delayed update for policy net and target value nets
        if self.update_step % self.delay_step == 0:
            new_actions, log_pi = self.policy_net.sample(states)
            min_q = torch.min(self.q_net1.forward(states, new_actions),
                              self.q_net2.forward(states, new_actions))
            policy_loss = (log_pi - min_q).mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            # Polyak-average the target value network without recording
            # an autograd graph.
            with torch.no_grad():
                for target_param, param in zip(self.target_value_net.parameters(),
                                               self.value_net.parameters()):
                    target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

        self.update_step += 1
class OldSACAgent:
    """Earlier SAC variant: one Q network, a value network with a target copy,
    and a policy trained with a likelihood-ratio style loss.

    Hyperparameters come from the YAML file named by
    ``config_info["config_param"]``. TensorBoard logs, network weights and the
    reward plot are written to the folder returned by ``create_run_folder``.
    """

    def __init__(self, env, render, config_info):
        self.env = env
        self.render = render
        self._reset_env()

        # Create run folder to store parameters, figures, and tensorboard logs
        self.path_runs = create_run_folder(config_info)

        # Extract training parameters from yaml config file
        param = load_training_parameters(config_info["config_param"])
        self.train_param = param["training"]

        # Define device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Device in use : {self.device}")

        # Define state and action dimension spaces
        state_dim = env.observation_space.shape[0]
        num_actions = env.action_space.shape[0]

        # Define models
        hidden_size = param["model"]["hidden_size"]
        self.q_net = QNetwork(state_dim, num_actions, hidden_size).to(self.device)
        self.v_net = VNetwork(state_dim, hidden_size).to(self.device)
        self.target_v_net = VNetwork(state_dim, hidden_size).to(self.device)
        self.target_v_net.load_state_dict(self.v_net.state_dict())
        self.policy_net = PolicyNetwork(state_dim, num_actions, hidden_size).to(
            self.device
        )

        # Define loss criterion
        self.q_criterion = nn.MSELoss()
        self.v_criterion = nn.MSELoss()

        # Define optimizers
        lr = float(param["optimizer"]["learning_rate"])
        self.q_opt = optim.Adam(self.q_net.parameters(), lr=lr)
        self.v_opt = optim.Adam(self.v_net.parameters(), lr=lr)
        self.policy_opt = optim.Adam(self.policy_net.parameters(), lr=lr)

        # Initialize replay buffer
        self.replay_buffer = ReplayBuffer(param["training"]["replay_size"])
        self.transition = namedtuple(
            "transition",
            field_names=["state", "action", "reward", "done", "next_state"],
        )

        # Useful variables
        self.batch_size = param["training"]["batch_size"]
        self.gamma = param["training"]["gamma"]
        self.tau = param["training"]["tau"]
        self.start_step = param["training"]["start_step"]
        self.max_timesteps = param["training"]["max_timesteps"]
        # Entropy temperature. Cast like the learning rate above, since YAML
        # may load values such as "1e-2" as strings.
        self.alpha = float(param["training"]["alpha"])

    def _reset_env(self):
        # Reset the environment and initialize episode reward
        self.state, self.done = self.env.reset(), False
        self.episode_reward = 0.0
        self.episode_step = 0

    @staticmethod
    def _update_reward_history(all_episode_rewards, all_mean_rewards, episode_reward):
        """Append an episode reward and its running mean, then return that mean.

        The mean is taken over the last 100 episodes, including this one.
        """
        all_episode_rewards.append(episode_reward)
        mean_reward = float(np.mean(all_episode_rewards[-100:]))
        all_mean_rewards.append(mean_reward)
        return mean_reward

    def train(self):
        """Main training loop.

        Runs episodes until more than ``max_timesteps`` environment steps have
        been taken. Then it saves the critic and actor weights and plots the
        rewards.
        """
        total_timestep = 0
        all_episode_rewards = []
        all_mean_rewards = []
        update = 0

        # Create tensorboard writer
        writer = SummaryWriter(log_dir=self.path_runs, comment="-sac")

        try:
            for episode in itertools.count(1, 1):
                self._reset_env()

                while not self.done:
                    # trick to improve exploration at the start of training
                    if self.start_step > total_timestep:
                        action = self.env.action_space.sample()  # Sample random action
                    else:
                        action = self.policy_net.get_action(
                            self.state, self.device
                        )  # Sample action from policy

                    # Only train once the buffer holds more than one batch
                    if len(self.replay_buffer) > self.batch_size:
                        batch = self.replay_buffer.sample_buffer(self.batch_size)
                        # Update parameters of all the networks
                        q_loss, v_loss, policy_loss = self.train_on_batch(batch)
                        writer.add_scalar("loss/q", q_loss, update)
                        writer.add_scalar("loss/v", v_loss, update)
                        writer.add_scalar("loss/policy", policy_loss, update)
                        update += 1

                    if self.render:
                        self.env.render()

                    # Perform one step in the environment
                    next_state, reward, self.done, _ = self.env.step(action)
                    total_timestep += 1
                    self.episode_step += 1
                    self.episode_reward += reward

                    # Create a tuple for the new transition and store it
                    new_transition = self.transition(
                        self.state, action, reward, self.done, next_state
                    )
                    self.replay_buffer.store_transition(new_transition)

                    self.state = next_state

                    if total_timestep > self.max_timesteps:
                        break

                # Record the episode BEFORE computing the mean so it is included.
                mean_reward = self._update_reward_history(
                    all_episode_rewards, all_mean_rewards, self.episode_reward
                )

                print(
                    "Episode n°{} ; total timestep [{}/{}] ; episode steps {} ; "
                    "reward {} ; mean reward {}".format(
                        episode,
                        total_timestep,
                        self.max_timesteps,
                        self.episode_step,
                        round(self.episode_reward, 2),
                        round(mean_reward, 2),
                    )
                )
                writer.add_scalar("reward", self.episode_reward, episode)
                writer.add_scalar("mean reward", mean_reward, episode)

                # The inner `break` only leaves the episode; stop training here too.
                if total_timestep > self.max_timesteps:
                    break

            # Save networks' weights
            path_critic = os.path.join(self.path_runs, "critic.pth")
            path_actor = os.path.join(self.path_runs, "actor.pth")
            torch.save(self.q_net.state_dict(), path_critic)
            torch.save(self.policy_net.state_dict(), path_actor)

            # Plot reward
            self.plot_reward(all_episode_rewards, all_mean_rewards)
        finally:
            # Close all, even if training raised
            writer.close()
            self.env.close()

    def train_on_batch(self, batch_samples):
        """Run one update of the Q, V and policy networks, then the soft target.

        Args:
            batch_samples: Tuple ``(states, actions, rewards, dones,
                next_states)`` as returned by ``ReplayBuffer.sample_buffer``.

        Returns:
            Tuple of Python floats ``(q_loss, v_loss, policy_loss)``.
        """
        (
            state_batch,
            action_batch,
            reward_batch,
            done_int_batch,
            next_state_batch,
        ) = batch_samples

        # Transform arrays into float32 tensors on the device. An explicit
        # dtype keeps float64 numpy data from mismatching float32 weights.
        def to_tensor(x):
            return torch.as_tensor(x, dtype=torch.float32, device=self.device)

        state_batch = to_tensor(state_batch)
        next_state_batch = to_tensor(next_state_batch)
        action_batch = to_tensor(action_batch)
        reward_batch = to_tensor(reward_batch).unsqueeze(1)
        done_int_batch = to_tensor(done_int_batch).unsqueeze(1)

        q_value, _ = self.q_net(state_batch, action_batch)
        value = self.v_net(state_batch)
        pi, log_pi = self.policy_net.sample(state_batch)

        ### Update Q
        target_next_value = self.target_v_net(next_state_batch)
        next_q_value = (
            reward_batch + (1 - done_int_batch) * self.gamma * target_next_value
        )
        q_loss = self.q_criterion(q_value, next_q_value.detach())

        ### Update V: target is Q(s, pi(s)) - alpha * log pi(a|s)
        q_pi, _ = self.q_net(state_batch, pi)
        next_value = q_pi - self.alpha * log_pi
        v_loss = self.v_criterion(value, next_value.detach())

        ### Update policy: likelihood-ratio form of E[alpha*log pi - Q + V].
        # With alpha == 1 this equals the previous unscaled loss.
        log_pi_target = q_pi - value
        policy_loss = (log_pi * (self.alpha * log_pi - log_pi_target).detach()).mean()

        # Losses and optimizers
        self.q_opt.zero_grad()
        q_loss.backward()
        self.q_opt.step()

        self.v_opt.zero_grad()
        v_loss.backward()
        self.v_opt.step()

        self.policy_opt.zero_grad()
        policy_loss.backward()
        self.policy_opt.step()

        soft_update(self.target_v_net, self.v_net, self.tau)

        return q_loss.item(), v_loss.item(), policy_loss.item()

    def plot_reward(self, data, mean_data):
        """Plot per-episode and mean rewards, save to the run folder, and show."""
        plt.plot(data, label="reward")
        plt.plot(mean_data, label="mean reward")
        plt.xlabel("Episode")
        plt.ylabel("Reward")
        plt.title(f"Reward evolution for {self.env.unwrapped.spec.id} Gym environment")
        plt.tight_layout()
        plt.legend()

        path_fig = os.path.join(self.path_runs, "figure.png")
        plt.savefig(path_fig)
        print(f"Figure saved to {path_fig}")

        plt.show()
class SACAgent:
    """Soft Actor-Critic agent with twin target Q networks and a learned temperature.

    Args:
        env: Environment exposing ``state_dim`` and ``action_dim`` attributes.
        gamma: Discount factor.
        tau: Soft-update rate for the target Q networks.
        alpha: Initial entropy temperature. When positive, it also seeds the
            learned ``log_alpha``.
        q_lr: Learning rate for both Q networks.
        policy_lr: Learning rate for the policy network.
        a_lr: Learning rate for ``log_alpha``.
        buffer_maxlen: Capacity passed to ``BasicBuffer``.
        action_range: Optional ``(low, high)`` bounds used by
            ``rescale_action``. Defaults to ``[0, 250]``.
    """

    def __init__(self, env, gamma, tau, alpha, q_lr, policy_lr, a_lr,
                 buffer_maxlen, action_range=None):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.env = env
        # Keep the historical default bounds unless the caller overrides them.
        self.action_range = [0, 250] if action_range is None else list(action_range)
        self.obs_dim = env.state_dim
        self.action_dim = env.action_dim

        # hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.update_step = 0
        # Policy and target Q networks are updated every `delay_step` calls.
        self.delay_step = 2

        # initialize networks
        self.q_net1 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.target_q_net1 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.target_q_net2 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.policy_net = PolicyNetwork(self.obs_dim, self.action_dim).to(self.device)

        # copy params to target param
        for target_param, param in zip(self.target_q_net1.parameters(), self.q_net1.parameters()):
            target_param.data.copy_(param)

        for target_param, param in zip(self.target_q_net2.parameters(), self.q_net2.parameters()):
            target_param.data.copy_(param)

        # initialize optimizers
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)

        # entropy temperature
        self.alpha = alpha
        # -prod([action_dim, 1]) == -action_dim
        self.target_entropy = -torch.prod(
            torch.Tensor([self.action_dim, 1]).to(self.device)).item()
        # Seed the learned temperature with the requested alpha. Otherwise the
        # value jumps to exp(0) = 1 after the first update. A non-positive
        # alpha keeps the previous log_alpha = 0 start.
        alpha_value = float(alpha)
        if alpha_value > 0:
            init_log_alpha = torch.tensor([alpha_value], device=self.device).log()
        else:
            init_log_alpha = torch.zeros(1, device=self.device)
        self.log_alpha = init_log_alpha.requires_grad_(True)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=a_lr)

        self.replay_buffer = BasicBuffer(buffer_maxlen)

    def get_action(self, state):
        """Sample a stochastic action for a single state, rescaled to action_range."""
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            mean, log_std = self.policy_net.forward(state)
            std = log_std.exp()
            z = Normal(mean, std).sample()
            action = torch.tanh(z)
        action = action.cpu().detach().squeeze(0).numpy()
        return self.rescale_action(action)

    def rescale_action(self, action):
        """Map an action from [-1, 1] onto [action_range[0], action_range[1]]."""
        return action * (self.action_range[1] - self.action_range[0]) / 2.0 +\
            (self.action_range[1] + self.action_range[0]) / 2.0

    def update(self, batch_size):
        """Run one SAC step: Q networks every call, then policy/targets every
        ``delay_step`` calls, then the temperature every call."""
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            batch_size)
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        # Force (batch, 1) so rewards don't broadcast to (batch, batch).
        rewards = rewards.view(rewards.size(0), -1)
        dones = dones.view(dones.size(0), -1)

        # Soft Bellman target from the target Q networks (no gradients needed).
        with torch.no_grad():
            next_actions, next_log_pi = self.policy_net.sample(next_states)
            next_q1 = self.target_q_net1(next_states, next_actions)
            next_q2 = self.target_q_net2(next_states, next_actions)
            next_q_target = torch.min(next_q1, next_q2) - self.alpha * next_log_pi
            expected_q = rewards + (1 - dones) * self.gamma * next_q_target

        # q loss
        curr_q1 = self.q_net1.forward(states, actions)
        curr_q2 = self.q_net2.forward(states, actions)
        q1_loss = F.mse_loss(curr_q1, expected_q)
        q2_loss = F.mse_loss(curr_q2, expected_q)

        # update q networks
        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()

        # Sampled outside the delayed branch because log_pi also drives the
        # temperature update below.
        new_actions, log_pi = self.policy_net.sample(states)
        if self.update_step % self.delay_step == 0:
            min_q = torch.min(
                self.q_net1.forward(states, new_actions),
                self.q_net2.forward(states, new_actions)
            )
            policy_loss = (self.alpha * log_pi - min_q).mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            # Polyak-average target networks without recording a graph.
            with torch.no_grad():
                for target_param, param in zip(self.target_q_net1.parameters(), self.q_net1.parameters()):
                    target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

                for target_param, param in zip(self.target_q_net2.parameters(), self.q_net2.parameters()):
                    target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

        # update temperature
        alpha_loss = (self.log_alpha * (-log_pi - self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        # Detach so the policy loss of the next update does not backpropagate
        # into log_alpha.
        self.alpha = self.log_alpha.exp().detach()

        self.update_step += 1
class SACAgent():
    """Soft Actor-Critic agent with twin target Q networks and reward scaling.

    The entropy temperature is implicitly 1. Sampled rewards are multiplied by
    ``reward_scale`` before the critic update instead.

    Args:
        env: Gym-style environment with Box action/observation spaces.
        gamma: Discount factor.
        tau: Soft-update rate for the target Q networks.
        buffer_maxlen: Capacity passed to ``BasicBuffer``.
        critic_lr: Learning rate for both Q networks.
        actor_lr: Learning rate for the policy network.
        reward_scale: Multiplier applied to sampled rewards.
    """

    def __init__(self, env: object, gamma: float, tau: float,
                 buffer_maxlen: int, critic_lr: float, actor_lr: float,
                 reward_scale: int):

        # Select the device: CUDA (GPU) if available, otherwise CPU
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # Environment used for training
        self.env = env

        # Min and max action values of this environment
        self.action_range = [
            self.env.action_space.low, self.env.action_space.high
        ]

        # Dimensions of the state and the action
        self.obs_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.shape[0]

        # hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.critic_lr = critic_lr
        self.actor_lr = actor_lr
        self.buffer_maxlen = buffer_maxlen
        self.reward_scale = reward_scale

        # Per-dimension scale and bias mapping [-1, 1] onto the env's action bounds
        self.scale = (self.action_range[1] - self.action_range[0]) / 2.0
        self.bias = (self.action_range[1] + self.action_range[0]) / 2.0

        # initialize networks
        self.q_net1 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.target_q_net1 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.target_q_net2 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.policy = PolicyNetwork(self.obs_dim, self.action_dim).to(self.device)

        # copy weight parameters to the target Q networks
        for target_param, param in zip(self.target_q_net1.parameters(), self.q_net1.parameters()):
            target_param.data.copy_(param)

        for target_param, param in zip(self.target_q_net2.parameters(), self.q_net2.parameters()):
            target_param.data.copy_(param)

        # initialize optimizers
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=self.critic_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=self.critic_lr)
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=self.actor_lr)

        # Create a replay buffer
        self.replay_buffer = BasicBuffer(self.buffer_maxlen)

    def update(self, batch_size: int):
        """Run one critic update, one policy update and one soft target update."""
        # Sampling experiences from the replay buffer
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            batch_size)

        # Convert numpy arrays of experience tuples into pytorch tensors
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        # Reward scaling stands in for the entropy temperature in this variant
        rewards = self.reward_scale * torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        # Force (batch, 1) so rewards don't broadcast to (batch, batch)
        rewards = rewards.view(rewards.size(0), -1)
        dones = dones.view(dones.size(0), -1)

        # Critic target (see eq. (6) of the SAC paper). No gradients are needed.
        with torch.no_grad():
            # Sample actions for the next states (s_t+1) from the current policy
            next_actions, next_log_pi, _, _ = self.policy.sample(next_states, self.scale)
            next_actions = self.rescale_action(next_actions)

            # Minimum of the two target Q networks at (s_t+1, a_t+1)
            next_q1 = self.target_q_net1(next_states, next_actions)
            next_q2 = self.target_q_net2(next_states, next_actions)
            min_q = torch.min(next_q1, next_q2)

            # Soft state value of s_t+1 (temperature 1)
            next_q_target = min_q - next_log_pi

            # Expected Q: r(t) + gamma * next_q_target
            expected_q = rewards + (1 - dones) * self.gamma * next_q_target

        # Q(s_t, a_t) for the stored transitions
        curr_q1 = self.q_net1.forward(states, actions)
        curr_q2 = self.q_net2.forward(states, actions)

        # Loss between each Q network and the expected Q
        q1_loss = F.mse_loss(curr_q1, expected_q)
        q2_loss = F.mse_loss(curr_q2, expected_q)

        # Backpropagate the losses and update Q network parameters
        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()

        # Policy update: sample new actions for the current states (s_t)
        new_actions, log_pi, _, _ = self.policy.sample(states, self.scale)
        new_actions = self.rescale_action(new_actions)

        # Minimum of the two online Q networks at (s_t, new_actions)
        new_q1 = self.q_net1.forward(states, new_actions)
        new_q2 = self.q_net2.forward(states, new_actions)
        min_q = torch.min(new_q1, new_q2)

        # Policy loss: log_pi - Q(s_t, a_t), eq. (7) with alpha = 1
        policy_loss = (log_pi - min_q).mean()

        # Backpropagate the loss and update policy network parameters
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Soft-update the target networks with rate tau, without recording a graph
        with torch.no_grad():
            for target_param, param in zip(self.target_q_net1.parameters(), self.q_net1.parameters()):
                target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

            for target_param, param in zip(self.target_q_net2.parameters(), self.q_net2.parameters()):
                target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

    def get_action(
            self, state: np.ndarray,
            stochastic: bool) -> Tuple[np.ndarray, torch.Tensor, torch.Tensor]:
        """Compute an action for a single state.

        Args:
            state: Observation fed to the policy network.
            stochastic: True samples from the policy Gaussian (training).
                False uses the Gaussian mean (evaluation).

        Returns:
            ``(rescaled_action, mean, std)``. The action is mapped from
            [-1, 1] onto the environment's bounds.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        # Mean and standard deviation from the policy network
        mean, log_std = self.policy.forward(state)
        std = log_std.exp()

        if stochastic:
            z = Normal(mean, std).sample()
            action = torch.tanh(z)
        else:
            # Deterministic action: squash the mean directly. Normal(mean, 0)
            # is rejected by torch.distributions argument validation.
            action = torch.tanh(mean)
        action = action.cpu().detach().squeeze(0).numpy()

        # The policy output lies in [-1, 1]; rescale it to the env's action range
        return self.rescale_action(action), mean, std

    def rescale_action(self, action):
        """Map an action from [-1, 1] to the env bounds using per-dimension scale/bias.

        Works for numpy arrays and for torch tensors. For tensors, scale and
        bias are moved to the tensor's dtype and device.
        """
        if isinstance(action, torch.Tensor):
            scale = torch.as_tensor(self.scale, dtype=action.dtype, device=action.device)
            bias = torch.as_tensor(self.bias, dtype=action.dtype, device=action.device)
            return action * scale + bias
        return action * self.scale + self.bias

    def Actor_save(self, WORKSPACE: str):
        # Save the policy model
        print("Save the torch model")
        savePath = WORKSPACE + "./policy_model5_Hop_.pth"
        torch.save(self.policy.state_dict(), savePath)

    def Actor_load(self, WORKSPACE: str):
        # Load the policy model (best checkpoint)
        print("load the torch model")
        savePath = WORKSPACE + "./policy_model5_Hop_.pth"
        self.policy = PolicyNetwork(self.obs_dim, self.action_dim).to(self.device)
        # map_location allows loading a GPU-saved checkpoint on a CPU-only machine
        self.policy.load_state_dict(torch.load(savePath, map_location=self.device))
        # Rebind the optimizer to the new network's parameters; the old
        # optimizer would otherwise keep updating the discarded network.
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=self.actor_lr)
class SACAgent:
    """SAC v1 agent (value + twin Q + policy) operating on image observations.

    Images of shape ``img_size`` (H, W, C) are passed through ``feature_net``
    and flattened to ``input_size`` features. The value, Q and policy networks
    consume those features.

    NOTE(review): no optimizer in this class receives
    ``feature_net.parameters()``, so the feature extractor keeps its initial
    weights unless something outside this class trains it.
    """

    def __init__(self, env, gamma, tau, v_lr, q_lr, policy_lr, buffer_maxlen):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.firsttime = 0

        self.env = env
        self.action_range = [env.action_space.low, env.action_space.high]
        self.action_dim = env.action_space.shape[0]

        self.conv_channels = 4
        self.kernel_size = (3, 3)

        self.img_size = (500, 500, 3)

        print("Diagnostics:")
        print(f"action_range: {self.action_range}")
        print(f"action_dim: {self.action_dim}")

        # hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.update_step = 0
        # Policy and target value network are updated every `delay_step` calls.
        self.delay_step = 2

        # initialize networks
        self.feature_net = FeatureExtractor(self.img_size[2], self.conv_channels,
                                            self.kernel_size).to(self.device)
        print("Feature net init'd successfully")

        input_dim = self.feature_net.get_output_size(self.img_size)
        self.input_size = input_dim[0] * input_dim[1] * input_dim[2]
        print(f"input_size: {self.input_size}")

        self.value_net = ValueNetwork(self.input_size, 1).to(self.device)
        self.target_value_net = ValueNetwork(self.input_size, 1).to(self.device)
        self.q_net1 = SoftQNetwork(self.input_size, self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.input_size, self.action_dim).to(self.device)
        self.policy_net = PolicyNetwork(self.input_size, self.action_dim).to(self.device)

        print("Finished initing all nets")

        # copy params to target param
        for target_param, param in zip(self.target_value_net.parameters(),
                                       self.value_net.parameters()):
            target_param.data.copy_(param)

        print("Finished copying targets")

        # initialize optimizers
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=v_lr)
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)

        print("Finished initing optimizers")

        self.replay_buffer = BasicBuffer(buffer_maxlen)
        print("End of init")

    def get_action(self, state):
        """Sample an action for one (H, W, C) image. Returns None on a shape mismatch."""
        if state.shape != self.img_size:
            print(
                f"Invalid size, expected shape {self.img_size}, got {state.shape}"
            )
            return None
        # (H, W, C) -> (1, C, H, W)
        inp = torch.from_numpy(state).float().permute(2, 0, 1).unsqueeze(0).to(
            self.device)
        features = self.feature_net(inp)
        features = features.view(-1, self.input_size)

        mean, log_std = self.policy_net.forward(features)
        std = log_std.exp()

        normal = Normal(mean, std)
        z = normal.sample()
        action = torch.tanh(z)
        action = action.cpu().detach().squeeze(0).numpy()
        return self.rescale_action(action)

    def rescale_action(self, action):
        """Map an action from [-1, 1] onto [action_range[0], action_range[1]]."""
        return action * (self.action_range[1] - self.action_range[0]) / 2.0 +\
            (self.action_range[1] + self.action_range[0]) / 2.0

    def update(self, batch_size):
        """Run one SAC v1 step on a batch of image transitions."""
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            batch_size)
        # states and next states are lists of ndarrays, np.stack converts them to
        # ndarrays of shape (batch_size, height, width, num_channels)
        states = np.stack(states)
        next_states = np.stack(next_states)
        # (B, H, W, C) -> (B, C, H, W)
        states = torch.FloatTensor(states).permute(0, 3, 1, 2).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).permute(0, 3, 1, 2).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        # Force (batch, 1) so rewards don't broadcast to (batch, batch).
        rewards = rewards.view(rewards.size(0), -1)
        dones = dones.view(dones.size(0), -1)

        # Process images; flatten using the actual batch size, not a fixed 64.
        features = self.feature_net(states)
        features = features.reshape(features.size(0), self.input_size)

        with torch.no_grad():
            # next_features only feed targets, so no graph is needed.
            next_features = self.feature_net(next_states)
            next_features = next_features.reshape(next_features.size(0), self.input_size)

            # Value target at the CURRENT states:
            # E_{a~pi(.|s)}[min(Q1, Q2)(s, a) - log pi(a|s)]
            sampled_actions, sampled_log_pi = self.policy_net.sample(features)
            v_target = torch.min(self.q_net1(features, sampled_actions),
                                 self.q_net2(features, sampled_actions)) - sampled_log_pi

            # Q target bootstraps from the target value network at s'.
            next_v = self.target_value_net(next_features)
            expected_q = rewards + (1 - dones) * self.gamma * next_v

        curr_v = self.value_net.forward(features)
        v_loss = F.mse_loss(curr_v, v_target)

        # q loss
        curr_q1 = self.q_net1.forward(features, actions)
        curr_q2 = self.q_net2.forward(features, actions)
        q1_loss = F.mse_loss(curr_q1, expected_q)
        q2_loss = F.mse_loss(curr_q2, expected_q)

        # update value and q networks; the losses share the feature_net graph,
        # hence retain_graph=True
        self.value_optimizer.zero_grad()
        v_loss.backward(retain_graph=True)
        self.value_optimizer.step()

        self.q1_optimizer.zero_grad()
        q1_loss.backward(retain_graph=True)
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward(retain_graph=True)
        self.q2_optimizer.step()

        # delayed update for policy network and target value network
        if self.update_step % self.delay_step == 0:
            new_actions, log_pi = self.policy_net.sample(features)
            min_q = torch.min(self.q_net1.forward(features, new_actions),
                              self.q_net2.forward(features, new_actions))
            policy_loss = (log_pi - min_q).mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward(retain_graph=True)
            self.policy_optimizer.step()

            # Polyak-average the target value network without recording a graph.
            with torch.no_grad():
                for target_param, param in zip(self.target_value_net.parameters(),
                                               self.value_net.parameters()):
                    target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

        self.update_step += 1