import copy
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F

# NOTE: Actor, Critic, CNNCritic, CNNCritic_target and the module-level
# quantile_huber_loss_f used below are assumed to be provided by the project's
# local model/utility modules; their import statements are not shown here.
class DDPG(object):
    def __init__(self, state_dim, action_dim, actor_input_dim, args):
        input_dim = [3, 84, 84]
        self.actor = Actor(state_dim, action_dim).to(args.device)
        self.actor_target = Actor(state_dim, action_dim).to(args.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), args.lr_actor)
        self.critic = CNNCritic(input_dim, state_dim, action_dim).to(args.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), args.lr_critic)
        self.target_critic = CNNCritic(input_dim, state_dim, action_dim).to(args.device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.max_action = 1
        self.step = 0
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.tau = args.tau
        self.device = args.device

    def select_action(self, state):
        state = torch.Tensor(state).to(self.device).div_(255)
        state = state.unsqueeze(0)
        state = self.critic.create_vector(state)
        return self.actor(state).cpu().data.numpy().flatten()

    def train(self, replay_buffer, writer, iterations):
        self.step += 1
        for it in range(iterations):
            # Step 4: sample a batch of transitions (s, a, r, s') from the replay buffer
            obs, action, reward, next_obs, not_done, obs_aug, next_obs_aug = replay_buffer.sample(
                self.batch_size)
            obs = obs.div_(255)
            next_obs = next_obs.div_(255)
            obs_aug = obs_aug.div_(255)
            next_obs_aug = next_obs_aug.div_(255)
            state = self.critic.create_vector(obs)
            detach_state = state.detach()
            next_state = self.critic.create_vector(next_obs)
            state_aug = self.critic.create_vector(obs_aug)
            detach_state_aug = state_aug.detach()
            next_state_aug = self.critic.create_vector(next_obs_aug)

            with torch.no_grad():
                # Step 5: from the next state s', the actor target plays the next action a'
                next_action = self.actor_target(next_state)
                target_Q = self.target_critic(next_state, next_action)
                # Step 9: final target Qt = r + γ * Q_target(s', a'), where γ is the discount factor
                target_Q = reward + (not_done * self.discount * target_Q).detach()
                # the same target computed from the augmented observations
                next_action_aug = self.actor_target(next_state_aug)
                target_aug_Q = self.target_critic(next_state_aug, next_action_aug)
                target_aug_Q = reward + (not_done * self.discount * target_aug_Q).detach()
                target_Q = (target_Q + target_aug_Q) / 2.

            current_Q1 = self.critic(state, action)
            critic_loss = F.mse_loss(current_Q1, target_Q)
            # same loss on the augmented observations
            Q1_aug = self.critic(state_aug, action)
            critic_loss += F.mse_loss(Q1_aug, target_Q)
            # Step 12: backpropagate the critic loss and update the critic parameters
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            actor_loss = -self.critic.Q1(detach_state, self.actor(detach_state)).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def hardupdate(self):
        pass

    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        self.target_critic = copy.deepcopy(self.critic)
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)
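# Illustrative sketch (not part of the original agent): the soft / Polyak target
# update performed at the end of DDPG.train() can be written as a standalone
# helper, theta_target <- tau * theta + (1 - tau) * theta_target.  The helper
# name is hypothetical and it is not called by the classes in this file.
def _soft_update_example(net, target_net, tau=0.005):
    """Blend the parameters of `net` into `target_net` in place."""
    with torch.no_grad():
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)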
class TCQ(object):
    def __init__(self, state_dim, action_dim, actor_input_dim, args):
        input_dim = [3, 84, 84]
        self.actor = Actor(state_dim, action_dim, args).to(args.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), args.lr_actor)
        self.decoder = CNNCritic(input_dim, state_dim, action_dim).to(args.device)
        self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), args.lr_critic)
        self.critic = Critic(state_dim, action_dim, args.n_quantiles, args.n_nets).to(args.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), args.lr_critic)
        self.critic_target = Critic(state_dim, action_dim, args.n_quantiles, args.n_nets).to(args.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.step = 0
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.tau = args.tau
        self.device = args.device
        self.top_quantiles_to_drop = args.top_quantiles_to_drop_per_net * args.n_nets
        self.target_entropy = -np.prod(action_dim)
        self.quantiles_total = self.critic.n_quantiles * self.critic.n_nets
        self.log_alpha = torch.zeros((1, ), requires_grad=True, device=self.device)
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=args.lr_alpha)
        self.total_it = 0

    def train(self, replay_buffer, writer, iterations):
        self.step += 1
        # Step 4: sample a batch of transitions (s, a, r, s') from the replay buffer
        sys.stdout = open(os.devnull, "w")
        obs, action, reward, next_obs, not_done, obs_list, next_obs_list = replay_buffer.sample(
            self.batch_size)
        sys.stdout = sys.__stdout__
        obs = obs.div_(255)
        next_obs = next_obs.div_(255)
        state = self.decoder.create_vector(obs)
        detach_state = state.detach()
        next_state = self.decoder.create_vector(next_obs)
        alpha = torch.exp(self.log_alpha)

        with torch.no_grad():
            # Step 5: sample the policy's next action and its log-probability
            next_action, next_log_pi = self.actor(next_state)
            # quantile estimates at the next state
            next_z = self.critic_target(next_obs_list, next_action)
            sorted_z, _ = torch.sort(next_z.reshape(self.batch_size, -1))
            sorted_z_part = sorted_z[:, :self.quantiles_total - self.top_quantiles_to_drop]
            # truncated, entropy-regularised TD target
            target = reward + not_done * self.discount * (sorted_z_part - alpha * next_log_pi)

        # --- update critic (and the shared decoder)
        cur_z = self.critic(obs_list, action)
        critic_loss = quantile_huber_loss_f(cur_z, target, self.device)
        self.critic_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        critic_loss.backward()
        self.decoder_optimizer.step()
        self.critic_optimizer.step()
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        # --- update policy and alpha
        new_action, log_pi = self.actor(detach_state)
        alpha_loss = -self.log_alpha * (log_pi + self.target_entropy).detach().mean()
        actor_loss = (alpha * log_pi - self.critic(
            obs_list, new_action).mean(2).mean(1, keepdim=True)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.total_it += 1

    def select_action(self, obs):
        obs = torch.FloatTensor(obs).to(self.device)
        obs = obs.div_(255)
        state = self.decoder.create_vector(obs.unsqueeze(0))
        return self.actor.select_action(state)

    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        self.critic_target = copy.deepcopy(self.critic)
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)
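# TCQ.train() above calls a module-level quantile_huber_loss_f(cur_z, target, device)
# that is assumed to come from a local utility module.  The sketch below mirrors the
# TQC.quantile_huber_loss_f method defined further down in this file; the helper name
# is hypothetical and it is not wired into the classes.
def _quantile_huber_loss_example(quantiles, samples, device):
    """Quantile Huber loss between (batch, nets, quantiles) estimates and (batch, m) targets."""
    # pairwise TD errors: batch x nets x quantiles x samples
    pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
    abs_pairwise_delta = torch.abs(pairwise_delta)
    huber_loss = torch.where(abs_pairwise_delta > 1,
                             abs_pairwise_delta - 0.5,
                             pairwise_delta ** 2 * 0.5)
    n_quantiles = quantiles.shape[2]
    # quantile midpoints tau_i = (i + 0.5) / n_quantiles
    tau = torch.arange(n_quantiles, device=device).float() / n_quantiles + 1 / 2 / n_quantiles
    return (torch.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()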
class TD31v1(object):
    def __init__(self, state_dim, action_dim, actor_input_dim, args):
        input_dim = [args.history_length, args.size, args.size]
        self.actor = Actor(state_dim, action_dim).to(args.device)
        self.actor_target = Actor(state_dim, action_dim).to(args.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), args.lr_actor)
        self.critic = CNNCritic(input_dim, state_dim, action_dim, args).to(args.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), args.lr_critic)
        self.target_critic = CNNCritic(input_dim, state_dim, action_dim, args).to(args.device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.list_target_critic = []
        for c in range(args.num_q_target):
            critic_target = CNNCritic(input_dim, state_dim, action_dim, args).to(args.device)
            critic_target.load_state_dict(self.critic.state_dict())
            self.list_target_critic.append(critic_target)
        self.num_q_target = args.num_q_target
        self.max_action = 1
        self.update_counter = 0
        self.currentQNet = 0
        self.step = 0
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.tau = args.tau
        self.policy_noise = args.policy_noise
        self.noise_clip = args.noise_clip
        self.policy_freq = args.policy_freq
        self.device = args.device
        self.actor_clip_gradient = args.actor_clip_gradient
        self.write_tensorboard = False

    def select_action(self, state):
        state = torch.Tensor(state).to(self.device).div_(255)
        state = state.unsqueeze(0)
        state = self.critic.create_vector(state)
        return self.actor(state).cpu().data.numpy().flatten()

    def train(self, replay_buffer, writer, iterations):
        self.step += 1
        if self.step % 1000 == 0:
            self.write_tensorboard = 1 - self.write_tensorboard
        for it in range(iterations):
            # Step 4: sample a batch of transitions (s, a, r, s') from the replay buffer
            obs, action, reward, next_obs, not_done, obs_aug, next_obs_aug = replay_buffer.sample(
                self.batch_size)
            obs = obs.div_(255)
            next_obs = next_obs.div_(255)
            obs_aug = obs_aug.div_(255)
            next_obs_aug = next_obs_aug.div_(255)
            state = self.critic.create_vector(obs)
            detach_state = state.detach()
            state_aug = self.critic.create_vector(obs_aug)
            next_state = self.target_critic.create_vector(next_obs)
            detach_state_aug = state_aug.detach()
            next_state_aug = self.target_critic.create_vector(next_obs_aug)

            with torch.no_grad():
                # Step 5: from the next state s', the actor target plays the next action a'
                next_action = self.actor_target(next_state)
                # Step 6: add clipped Gaussian noise to a' and clamp it to the action range
                noise = (torch.randn_like(action) * self.policy_noise).clamp(
                    -self.noise_clip, self.noise_clip)
                next_action = (next_action + noise).clamp(-self.max_action, self.max_action)
                # Step 7: each frozen target critic returns two Q-values; take their min
                # and average over the ensemble
                target_Q = 0
                for critic in self.list_target_critic:
                    target_Q1, target_Q2 = critic(critic.create_vector(next_obs), next_action)
                    target_Q += torch.min(target_Q1, target_Q2)
                target_Q *= 1. / self.num_q_target
                # Step 9: final target Qt = r + γ * (averaged min(Qt1, Qt2)), where γ is the discount factor
                target_Q = reward + (not_done * self.discount * target_Q).detach()
                # again with the augmented observations
                next_action_aug = self.actor_target(next_state_aug)
                noise = (torch.randn_like(action) * self.policy_noise).clamp(
                    -self.noise_clip, self.noise_clip)
                next_action_aug = (next_action_aug + noise).clamp(-self.max_action, self.max_action)
                target_aug_Q = 0
                for idx, critic in enumerate(self.list_target_critic):
                    target_Q1, target_Q2 = critic(critic.create_vector(next_obs_aug), next_action_aug)
                    target_aug_Q_min = torch.min(target_Q1, target_Q2)
                    if self.write_tensorboard:
                        writer.add_scalar('Critic-{} q'.format(idx), target_aug_Q_min.mean(), self.step)
                    target_aug_Q += target_aug_Q_min
                target_aug_Q *= 1. / self.num_q_target
                target_aug_Q = reward + (not_done * self.discount * target_aug_Q).detach()
                target_Q = (target_Q + target_aug_Q) / 2.

            current_Q1, current_Q2 = self.critic(state, action)
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
            if self.write_tensorboard:
                writer.add_scalar('Critic loss', critic_loss, self.step)
            # same loss on the augmented observations
            Q1_aug, Q2_aug = self.critic(state_aug, action)
            critic_loss += F.mse_loss(Q1_aug, target_Q) + F.mse_loss(Q2_aug, target_Q)
            # Step 12: backpropagate the critic loss and update the critic parameters
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Step 13: delayed policy update - every policy_freq iterations, update the actor
            # by gradient ascent on the critic's first Q-head
            if it % self.policy_freq == 0:
                obs = replay_buffer.sample_actor(self.batch_size)
                obs = obs.div_(255)
                state = self.critic.create_vector(obs)
                actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
                if self.write_tensorboard:
                    writer.add_scalar('Actor loss', actor_loss, self.step)
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                # clip the actor gradient
                torch.nn.utils.clip_grad_value_(self.actor.parameters(), self.actor_clip_gradient)
                self.actor_optimizer.step()
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
                for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def hardupdate(self):
        self.update_counter += 1
        self.currentQNet = self.update_counter % self.num_q_target
        # hard-copy the running target critic into the current slot of the ensemble
        for param, target_param in zip(self.target_critic.parameters(),
                                       self.list_target_critic[self.currentQNet].parameters()):
            target_param.data.copy_(param.data)

    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        self.target_critic = copy.deepcopy(self.critic)
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)
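# TD31v1 above keeps num_q_target frozen critic copies, averages their clipped
# double-Q estimates for the TD target, and hardupdate() overwrites one copy at a
# time in round-robin order.  A minimal sketch of that rotation; the helper name
# and arguments are illustrative assumptions and it is not used by the agent:
def _round_robin_hard_update_example(source_net, target_nets, update_counter):
    """Hard-copy `source_net` into target_nets[update_counter % len(target_nets)]."""
    idx = update_counter % len(target_nets)
    with torch.no_grad():
        for param, target_param in zip(source_net.parameters(), target_nets[idx].parameters()):
            target_param.data.copy_(param.data)
    return idx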
class TD31v1(object):
    def __init__(self, state_dim, action_dim, actor_input_dim, args):
        input_dim = [3, 84, 84]
        self.actor = Actor(state_dim, action_dim).to(args.device)
        self.actor_target = Actor(state_dim, action_dim).to(args.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), args.lr_actor)
        self.critic = CNNCritic(input_dim, state_dim, action_dim).to(args.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), args.lr_critic)
        self.target_critic = CNNCritic(input_dim, state_dim, action_dim).to(args.device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.max_action = 1
        self.update_counter = 0
        self.step = 0
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.tau = args.tau
        self.policy_noise = args.policy_noise
        self.noise_clip = args.noise_clip
        self.policy_freq = args.policy_freq
        self.device = args.device
        self.actor_clip_gradient = args.actor_clip_gradient

    def select_action(self, state):
        state = torch.Tensor(state).to(self.device).div_(255)
        state = state.unsqueeze(0)
        state = self.critic.create_vector(state)
        return self.actor(state).cpu().data.numpy().flatten()

    def train(self, replay_buffer, writer, iterations):
        self.step += 1
        for it in range(iterations):
            # Step 4: sample a batch of transitions (s, a, r, s') from the replay buffer
            obs, action, reward, next_obs, not_done = replay_buffer.sample(self.batch_size)
            obs = obs.div_(255)
            next_obs = next_obs.div_(255)
            state = self.critic.create_vector(obs)
            detach_state = state.detach()
            next_state = self.critic.create_vector(next_obs)

            with torch.no_grad():
                # Step 5: from the next state s', the actor target plays the next action a'
                next_action = self.actor_target(next_state)
                noise = (torch.randn_like(action) * self.policy_noise).clamp(
                    -self.noise_clip, self.noise_clip)
                next_action = (next_action + noise).clamp(-self.max_action, self.max_action)
                # Step 7: the two target critics return Qt1(s', a') and Qt2(s', a')
                target_Q1, target_Q2 = self.target_critic(next_state, next_action)
                target_Q = torch.min(target_Q1, target_Q2)
                # Step 9: final target Qt = r + γ * min(Qt1, Qt2), where γ is the discount factor
                target_Q = reward + (not_done * self.discount * target_Q).detach()

            current_Q1, current_Q2 = self.critic(state, action)
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
            # Step 12: backpropagate the critic loss and update the critic parameters
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Step 13: delayed policy update - every policy_freq iterations, update the actor
            # by gradient ascent on the critic's first Q-head
            if it % self.policy_freq == 0:
                actor_loss = -self.critic.Q1(detach_state, self.actor(detach_state)).mean()
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                # clip the actor gradient
                torch.nn.utils.clip_grad_value_(self.actor.parameters(), self.actor_clip_gradient)
                self.actor_optimizer.step()
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
                for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def hardupdate(self):
        pass

    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        self.target_critic = copy.deepcopy(self.critic)
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)
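# The target-policy smoothing used in train() above adds clipped Gaussian noise to
# the target actor's action before querying the target critics.  A minimal,
# standalone sketch of that step (the function name is an illustrative assumption):
def _td3_smoothed_target_action_example(next_action, policy_noise=0.2,
                                        noise_clip=0.5, max_action=1.0):
    """Return a' + clip(N(0, policy_noise), -noise_clip, noise_clip), clamped to the action range."""
    noise = (torch.randn_like(next_action) * policy_noise).clamp(-noise_clip, noise_clip)
    return (next_action + noise).clamp(-max_action, max_action)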
class TD31v1(object):
    def __init__(self, state_dim, action_dim, actor_input_dim, args):
        input_dim = [3, 84, 84]
        self.actor = Actor(state_dim, action_dim).to(args.device)
        self.actor_target = Actor(state_dim, action_dim).to(args.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), args.lr_actor)
        self.critic = CNNCritic(input_dim, state_dim, action_dim).to(args.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), args.lr_critic)
        self.list_target_critic = []
        for c in range(args.num_q_target):
            critic_target = CNNCritic_target(state_dim, action_dim).to(args.device)
            critic_target.load_state_dict(critic_target.state_dict())
            self.list_target_critic.append(critic_target)
        self.target_critic = CNNCritic_target(state_dim, action_dim).to(args.device)
        self.target_critic.load_state_dict(self.target_critic.state_dict())
        self.max_action = 1
        self.num_q_target = args.num_q_target
        self.update_counter = 0
        self.step = 0
        self.currentQNet = 0
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.tau = args.tau
        self.policy_noise = args.policy_noise
        self.noise_clip = args.noise_clip
        self.policy_freq = args.policy_freq
        self.device = args.device
        self.actor_clip_gradient = args.actor_clip_gradient

    def select_action(self, state):
        state = torch.Tensor(state).to(self.device).div_(255)
        state = state.unsqueeze(0)
        state = self.critic.create_vector(state)
        return self.actor(state).cpu().data.numpy().flatten()

    def train(self, replay_buffer, writer, iterations):
        self.step += 1
        for it in range(iterations):
            # Step 4: sample a batch of transitions (s, a, r, s') from the replay buffer
            batch_states, batch_next_states, batch_actions, batch_rewards, batch_dones = replay_buffer.sample(
                self.batch_size)
            state_image = torch.Tensor(batch_states).to(self.device).div_(255)
            next_state = torch.Tensor(batch_next_states).to(self.device).div_(255)
            action = torch.Tensor(batch_actions).to(self.device)
            reward = torch.Tensor(batch_rewards).to(self.device)
            done = torch.Tensor(batch_dones).to(self.device)
            state = self.critic.create_vector(state_image)
            detach_state = state.detach()
            next_state = self.critic.create_vector(next_state)

            with torch.no_grad():
                # Step 5: from the next state s', the actor target plays the next action a'
                next_action = self.actor_target(next_state)
                # Step 6: add Gaussian noise to a' and clamp it to the action range
                noise = torch.Tensor(batch_actions).data.normal_(0, self.policy_noise).to(self.device)
                noise = noise.clamp(-self.noise_clip, self.noise_clip)
                next_action = (next_action + noise).clamp(-self.max_action, self.max_action)
                # Step 7: each target critic returns two Q-values; take their min
                # and average over the ensemble
                target_Q = 0
                for critic in self.list_target_critic:
                    target_Q1, target_Q2 = critic(next_state, next_action)
                    target_Q += torch.min(target_Q1, target_Q2)
                target_Q *= 1. / self.num_q_target
                # Step 9: final target Qt = r + γ * (averaged min(Qt1, Qt2)), where γ is the discount factor
                target_Q = reward + ((1 - done) * self.discount * target_Q).detach()

            # Step 10: the critic returns Q1(s, a) and Q2(s, a)
            current_Q1, current_Q2 = self.critic(state, action)
            # Step 11: critic loss = MSE(Q1(s, a), Qt) + MSE(Q2(s, a), Qt)
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
            # Step 12: backpropagate the critic loss and update the critic parameters
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Step 13: delayed policy update - every policy_freq iterations, update the actor
            # by gradient ascent on the critic's first Q-head
            if it % self.policy_freq == 0:
                actor_loss = -self.critic.Q1(detach_state, self.actor(detach_state)).mean()
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                # clip the actor gradient
                torch.nn.utils.clip_grad_value_(self.actor.parameters(), self.actor_clip_gradient)
                self.actor_optimizer.step()
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
                for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def hardupdate(self):
        self.update_counter += 1
        self.currentQNet = self.update_counter % self.num_q_target
        # Step 15: periodically hard-copy the running target critic into the current
        # slot of the target-critic ensemble
        for param, target_param in zip(self.target_critic.parameters(),
                                       self.list_target_critic[self.currentQNet].parameters()):
            target_param.data.copy_(param.data)

    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        self.critic_target = copy.deepcopy(self.critic)
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)
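# Steps 7-9 above reduce to the clipped double-Q target of TD3, here averaged over an
# ensemble of target critics.  For a single critic pair the target is simply
# r + (1 - done) * γ * min(Q1, Q2); a minimal sketch (illustrative helper only):
def _clipped_double_q_target_example(reward, done, target_q1, target_q2, discount=0.99):
    """TD target using the minimum of two target Q-estimates."""
    return reward + (1.0 - done) * discount * torch.min(target_q1, target_q2)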
class TQC(object):
    def __init__(self, state_dim, action_dim, actor_input_dim, args):
        input_dim = [args.history_length, args.size, args.size]
        self.actor = Actor(state_dim, action_dim, args).to(args.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), args.lr_actor)
        self.critic = CNNCritic(input_dim, state_dim, action_dim, args).to(args.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), args.lr_critic)
        self.target_critic = CNNCritic(input_dim, state_dim, action_dim, args).to(args.device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.tau = args.tau
        self.device = args.device
        self.write_tensorboard = False
        self.top_quantiles_to_drop = args.top_quantiles_to_drop
        self.target_entropy = args.target_entropy
        self.quantiles_total = self.critic.n_quantiles * self.critic.n_nets
        self.log_alpha = torch.zeros((1, ), requires_grad=True, device=args.device)
        # used by the delayed actor update in train()
        self.policy_freq = args.policy_freq
        self.actor_clip_gradient = args.actor_clip_gradient
        self.total_it = 0
        self.step = 0

    def train(self, replay_buffer, writer, iterations):
        self.step += 1
        if self.step % 1000 == 0:
            self.write_tensorboard = 1 - self.write_tensorboard
        for it in range(iterations):
            # Step 4: sample a batch of transitions (s, a, r, s') from the replay buffer
            obs, action, reward, next_obs, not_done, obs_aug, next_obs_aug = replay_buffer.sample(
                self.batch_size)
            obs = obs.div_(255)
            next_obs = next_obs.div_(255)
            obs_aug = obs_aug.div_(255)
            next_obs_aug = next_obs_aug.div_(255)
            state = self.critic.create_vector(obs)
            detach_state = state.detach()
            state_aug = self.critic.create_vector(obs_aug)
            next_state = self.target_critic.create_vector(next_obs)
            detach_state_aug = state_aug.detach()
            next_state_aug = self.target_critic.create_vector(next_obs_aug)
            alpha = torch.exp(self.log_alpha)

            with torch.no_grad():
                # Step 5: sample the policy's next action and its log-probability
                new_next_action, next_log_pi = self.actor(next_state)
                # quantile estimates at the next state
                next_z = self.target_critic(next_state, new_next_action)
                sorted_z, _ = torch.sort(next_z.reshape(self.batch_size, -1))
                sorted_z_part = sorted_z[:, :self.quantiles_total - self.top_quantiles_to_drop]
                # truncated, entropy-regularised TD target
                target = reward + not_done * self.discount * (sorted_z_part - alpha * next_log_pi)

            # update the critic on the truncated-quantile target
            cur_z = self.critic(state, action)
            critic_loss = self.quantile_huber_loss_f(cur_z, target)
            if self.write_tensorboard:
                writer.add_scalar('Critic loss', critic_loss, self.step)
            # same loss on the augmented observations
            cur_z_aug = self.critic(state_aug, action)
            critic_loss += self.quantile_huber_loss_f(cur_z_aug, target)
            # Step 12: backpropagate the critic loss and update the critic parameters
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Step 13: delayed policy update - every policy_freq iterations, the actor
            # maximises the (entropy-regularised) mean of the critic's quantiles
            if it % self.policy_freq == 0:
                obs = replay_buffer.sample_actor(self.batch_size)
                obs = obs.div_(255)
                state = self.critic.create_vector(obs)
                new_action, log_pi = self.actor(state)
                actor_loss = (alpha * log_pi - self.critic(
                    state, new_action).mean(2).mean(1, keepdim=True)).mean()
                if self.write_tensorboard:
                    writer.add_scalar('Actor loss', actor_loss, self.step)
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                # clip the actor gradient
                torch.nn.utils.clip_grad_value_(self.actor.parameters(), self.actor_clip_gradient)
                self.actor_optimizer.step()
                # soft-update the target critic
                for param, target_param in zip(self.critic.parameters(),
                                               self.target_critic.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            self.total_it += 1

    def hardupdate(self):
        pass

    def quantile_huber_loss_f(self, quantiles, samples):
        # pairwise TD errors: batch x nets x quantiles x samples
        pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
        abs_pairwise_delta = torch.abs(pairwise_delta)
        huber_loss = torch.where(abs_pairwise_delta > 1,
                                 abs_pairwise_delta - 0.5,
                                 pairwise_delta**2 * 0.5)
        n_quantiles = quantiles.shape[2]
        tau = torch.arange(n_quantiles, device=self.device).float() / n_quantiles + 1 / 2 / n_quantiles
        loss = (torch.abs(tau[None, None, :, None] -
                          (pairwise_delta < 0).float()) * huber_loss).mean()
        return loss

    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        self.target_critic = copy.deepcopy(self.critic)
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)
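# The distinctive step in TQC.train() is truncation: the quantile estimates of all
# critic nets are pooled, sorted, and the largest top_quantiles_to_drop atoms are
# discarded before forming the entropy-regularised TD target.  A minimal sketch of
# that computation (the helper name and signature are illustrative assumptions):
def _truncated_quantile_target_example(next_z, reward, not_done, next_log_pi,
                                       alpha, top_quantiles_to_drop, discount=0.99):
    """Sort pooled quantiles, drop the top ones, and build r + γ * (z - alpha * log_pi)."""
    batch_size = next_z.shape[0]
    sorted_z, _ = torch.sort(next_z.reshape(batch_size, -1))
    kept = sorted_z[:, :sorted_z.shape[1] - top_quantiles_to_drop]
    return reward + not_done * discount * (kept - alpha * next_log_pi)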