# Standard-library and third-party imports used throughout this section.
# The model classes (Actor, Critic, CNNCritic, Decoder) and the helpers
# quantile_huber_loss_f and write_into_file are assumed to come from the
# project's own modules.
import copy
import os
import sys
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F


class TCQ(object):
    def __init__(self, state_dim, action_dim, actor_input_dim, args):
        input_dim = [3, 84, 84]
        self.actor = Actor(state_dim, action_dim, args).to(args.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), args.lr_actor)
        self.decoder = CNNCritic(input_dim, state_dim, action_dim).to(args.device)
        self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), args.lr_critic)
        self.critic = Critic(state_dim, action_dim, args.n_quantiles, args.n_nets).to(args.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), args.lr_critic)
        self.critic_target = Critic(state_dim, action_dim, args.n_quantiles, args.n_nets).to(args.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.step = 0
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.tau = args.tau
        self.device = args.device
        self.top_quantiles_to_drop = args.top_quantiles_to_drop_per_net * args.n_nets
        self.target_entropy = -np.prod(action_dim)
        self.quantiles_total = self.critic.n_quantiles * self.critic.n_nets
        self.log_alpha = torch.zeros((1,), requires_grad=True, device=self.device)
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=args.lr_alpha)
        self.total_it = 0

    def train(self, replay_buffer, writer, iterations):
        self.step += 1
        # Sample a batch of transitions (s, a, r, s', not_done) from the replay buffer;
        # stdout is silenced temporarily to suppress the buffer's own prints.
        sys.stdout = open(os.devnull, "w")
        obs, action, reward, next_obs, not_done, obs_list, next_obs_list = replay_buffer.sample(self.batch_size)
        sys.stdout = sys.__stdout__

        obs = obs.div_(255)
        next_obs = next_obs.div_(255)
        state = self.decoder.create_vector(obs)
        detach_state = state.detach()
        next_state = self.decoder.create_vector(next_obs)
        alpha = torch.exp(self.log_alpha)

        with torch.no_grad():
            # Policy action at the next state.
            next_action, next_log_pi = self.actor(next_state)
            # Quantiles at the next state: pool, sort and drop the top ones.
            next_z = self.critic_target(next_obs_list, next_action)
            sorted_z, _ = torch.sort(next_z.reshape(self.batch_size, -1))
            sorted_z_part = sorted_z[:, :self.quantiles_total - self.top_quantiles_to_drop]
            # Truncated, entropy-regularized target.
            target = reward + not_done * self.discount * (sorted_z_part - alpha * next_log_pi)

        # --- update critic (and the decoder that feeds it)
        cur_z = self.critic(obs_list, action)
        critic_loss = quantile_huber_loss_f(cur_z, target, self.device)
        self.critic_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        critic_loss.backward()
        self.decoder_optimizer.step()
        self.critic_optimizer.step()

        # Polyak-average the critic into its target.
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        # --- update policy and temperature alpha
        new_action, log_pi = self.actor(detach_state)
        alpha_loss = -self.log_alpha * (log_pi + self.target_entropy).detach().mean()
        actor_loss = (alpha * log_pi - self.critic(obs_list, new_action).mean(2).mean(1, keepdim=True)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.total_it += 1

    def select_action(self, obs):
        obs = torch.FloatTensor(obs).to(self.device)
        obs = obs.div_(255)
        state = self.decoder.create_vector(obs.unsqueeze(0))
        return self.actor.select_action(state)

    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        self.critic_target = copy.deepcopy(self.critic)
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)
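# The truncation step inside train() is the core TQC idea: pool the quantile
# estimates of all critic nets, sort them, and drop the largest ones before
# building the target, which counteracts overestimation. A minimal,
# self-contained sketch of that step on random tensors (the shapes and the
# number of dropped quantiles below are illustrative assumptions, not values
# taken from the agent above):
if __name__ == "__main__":
    import torch

    batch_size, n_nets, n_quantiles = 4, 5, 25
    top_quantiles_to_drop = 2 * n_nets          # e.g. 2 dropped per net

    # Fake critic output: one quantile estimate per (net, quantile).
    next_z = torch.randn(batch_size, n_nets, n_quantiles)

    # Pool all nets' quantiles, sort, and keep only the smallest ones.
    sorted_z, _ = torch.sort(next_z.reshape(batch_size, -1))
    quantiles_total = n_nets * n_quantiles
    sorted_z_part = sorted_z[:, :quantiles_total - top_quantiles_to_drop]

    # Dropping the top quantiles lowers the value estimate on average.
    print("mean of all quantiles:      ", sorted_z.mean().item())
    print("mean of truncated quantiles:", sorted_z_part.mean().item())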
class TQC(object):
    def __init__(self, state_dim, action_dim, actor_input_dim, args):
        input_dim = [args.history_length, args.size, args.size]
        self.actor = Actor(state_dim, action_dim, args).to(args.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), args.lr_actor)
        self.critic = Critic(state_dim, action_dim, args).to(args.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), args.lr_critic)
        self.target_critic = Critic(state_dim, action_dim, args).to(args.device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.decoder = Decoder(args).to(args.device)
        self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), args.lr_decoder)
        self.target_decoder = Decoder(args).to(args.device)
        self.target_decoder.load_state_dict(self.decoder.state_dict())
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.tau = args.tau
        self.device = args.device
        self.write_tensorboard = False
        self.top_quantiles_to_drop = args.top_quantiles_to_drop_per_net * args.n_nets
        self.n_nets = args.n_nets
        self.top_quantiles_to_drop_per_net = args.top_quantiles_to_drop_per_net
        self.target_entropy = args.target_entropy
        self.quantiles_total = self.critic.n_quantiles * self.critic.n_nets
        self.log_alpha = torch.zeros((1,), requires_grad=True, device=args.device)
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=args.lr_alpha)
        self.total_it = 0
        self.step = 0
        self.beta = 0.5

    def update_beta(self, replay_buffer, writer, total_timesteps):
        # not_done is 0 at the step where the episode ends.
        obs, obs_aug, action, reward, not_done = replay_buffer.get_last_k_trajectories()
        store_batch = []
        obs = obs.div_(255)
        obs_aug = obs_aug.div_(255)

        # Build the discounted return R_i for every step of the (single) trajectory.
        Ri = []
        tmp = 0
        i = 0
        k = obs.shape[0]
        for idx in range(obs.shape[0]):
            if not_done[idx][0] == 0:
                # Don't forget to add the last reward.
                Ri.append(reward[idx][0] * self.discount ** k)
                break
            tmp += self.discount ** (k - i) * reward[idx][0]
            Ri.append(deepcopy(tmp))
            i += 1

        for idx, ri in enumerate(Ri):
            store_batch.append((obs[idx], obs_aug[idx], action[idx], ri))

        # Compare the critic's value estimate against the observed return.
        delta = 0
        for b in store_batch:
            s, s1, a, r = b[0], b[1], b[2], b[3].data.item()
            a = a.unsqueeze(0)
            s = s.unsqueeze(0)
            s1 = s1.unsqueeze(0)
            # First augmentation.
            state_aug = self.decoder.create_vector(s)
            next_z = self.critic(state_aug.detach(), a.detach())
            Q = 0
            for net in next_z[0]:
                Q += torch.mean(net).data.item()
            Q *= 1. / self.n_nets
            # Second augmentation.
            state_aug1 = self.decoder.create_vector(s1)
            next_z1 = self.critic(state_aug1.detach(), a.detach())
            Q_aug = 0
            for net in next_z1[0]:
                Q_aug += torch.mean(net).data.item()
            Q_aug *= 1. / self.n_nets

            Q_all = (Q + Q_aug) / 2.
            dif = Q_all - r
            text = "Predicted Q: {} return r: {} dif {}".format(Q_all, r, dif)
            write_into_file("debug_beta", text)
            delta += dif
        delta *= (1. / len(store_batch))

        writer.add_scalar('delta', delta, total_timesteps)
        dif = self.beta - delta
        writer.add_scalar('dif', dif, total_timesteps)
        self.beta = max(0., min(dif, 1.))
        writer.add_scalar('beta', self.beta, total_timesteps)

        # Adapt the number of dropped quantiles based on beta: above 0.8, drop
        # one fewer quantile (down to 0); below 0.2, drop one more (up to 35).
        if self.beta > 0.8 and self.top_quantiles_to_drop > 0:
            self.top_quantiles_to_drop -= 1
        if self.beta < 0.2 and self.top_quantiles_to_drop < 35:
            self.top_quantiles_to_drop += 1
        text = "Delta d: {} beta: {} current top quantiles to drop {}".format(
            delta, self.beta, self.top_quantiles_to_drop)
        write_into_file("debug_beta", text)
        writer.add_scalar('drop-quantile', self.top_quantiles_to_drop, total_timesteps)

    def train(self, replay_buffer, writer, iterations):
        self.step += 1
        if self.step % 1000 == 0:
            self.write_tensorboard = 1 - self.write_tensorboard
        for it in range(iterations):
            # Sample a batch of transitions (s, a, r, s', not_done) plus augmented
            # observations; stdout is silenced to suppress the buffer's prints.
            sys.stdout = open(os.devnull, "w")
            obs, action, reward, next_obs, not_done, obs_aug, obs_next_aug = replay_buffer.sample(self.batch_size)
            sys.stdout = sys.__stdout__

            # First augmentation.
            obs = obs.div_(255)
            next_obs = next_obs.div_(255)
            state = self.decoder.create_vector(obs)
            detach_state = state.detach()
            next_state = self.target_decoder.create_vector(next_obs)

            # Second augmentation.
            obs_aug = obs_aug.div_(255)
            next_obs_aug = obs_next_aug.div_(255)
            state_aug = self.decoder.create_vector(obs_aug)
            detach_state_aug = state_aug.detach()
            next_state_aug = self.target_decoder.create_vector(next_obs_aug)

            alpha = torch.exp(self.log_alpha)
            with torch.no_grad():
                # Policy action at the next state.
                new_next_action, next_log_pi = self.actor(next_state)
                # Quantiles at the next state: pool, sort and truncate.
                next_z = self.target_critic(next_state, new_next_action)
                sorted_z, _ = torch.sort(next_z.reshape(self.batch_size, -1))
                sorted_z_part = sorted_z[:, :self.quantiles_total - self.top_quantiles_to_drop]
                target = reward + not_done * self.discount * (sorted_z_part - alpha * next_log_pi)

                # Again for the augmented observation.
                new_next_action_aug, next_log_pi_aug = self.actor(next_state_aug)
                next_z_aug = self.target_critic(next_state_aug, new_next_action_aug)
                sorted_z_aug, _ = torch.sort(next_z_aug.reshape(self.batch_size, -1))
                sorted_z_part_aug = sorted_z_aug[:, :self.quantiles_total - self.top_quantiles_to_drop]
                target_aug = reward + not_done * self.discount * (sorted_z_part_aug - alpha * next_log_pi_aug)
                target = (target + target_aug) / 2.

            # --- update critic (and the decoder that feeds it)
            cur_z = self.critic(state, action)
            critic_loss = quantile_huber_loss_f(cur_z, target, self.device)
            # Augmented observation against the same averaged target.
            cur_z_aug = self.critic(state_aug, action)
            critic_loss += quantile_huber_loss_f(cur_z_aug, target, self.device)
            self.critic_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()
            critic_loss.backward()
            self.decoder_optimizer.step()
            self.critic_optimizer.step()

            # Polyak-average critic and decoder into their targets.
            for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for param, target_param in zip(self.decoder.parameters(), self.target_decoder.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            # --- update policy and temperature alpha
            new_action, log_pi = self.actor(detach_state)
            alpha_loss = -self.log_alpha * (log_pi + self.target_entropy).detach().mean()
            actor_loss = (alpha * log_pi - self.critic(detach_state, new_action).mean(2).mean(1, keepdim=True)).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.total_it += 1

    def select_action(self, obs):
        obs = torch.FloatTensor(obs).to(self.device)
        obs = obs.div_(255)
        state = self.decoder.create_vector(obs.unsqueeze(0))
        return self.actor.select_action(state)

    def quantile_huber_loss_f(self, quantiles, samples):
        # pairwise_delta: batch x nets x quantiles x samples
        pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
        abs_pairwise_delta = torch.abs(pairwise_delta)
        huber_loss = torch.where(abs_pairwise_delta > 1,
                                 abs_pairwise_delta - 0.5,
                                 pairwise_delta ** 2 * 0.5)
        n_quantiles = quantiles.shape[2]
        tau = torch.arange(n_quantiles, device=self.device).float() / n_quantiles + 1 / 2 / n_quantiles
        loss = (torch.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
        return loss

    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")
        torch.save(self.decoder.state_dict(), filename + "_decoder")
        torch.save(self.decoder_optimizer.state_dict(), filename + "_decoder_optimizer")

    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        # The target attribute in this class is target_critic, not critic_target.
        self.target_critic = copy.deepcopy(self.critic)
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
        self.decoder.load_state_dict(torch.load(filename + "_decoder"))
        self.decoder_optimizer.load_state_dict(torch.load(filename + "_decoder_optimizer"))
        self.target_decoder = copy.deepcopy(self.decoder)
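# quantile_huber_loss_f pairs every predicted quantile with every target sample
# and weights the Huber error asymmetrically by the quantile level tau, so each
# quantile head is pushed toward its own quantile of the target distribution.
# A small, self-contained shape check with dummy tensors (the sizes below are
# illustrative assumptions):
if __name__ == "__main__":
    import torch

    batch, n_nets, n_quantiles, n_samples = 4, 5, 25, 23
    quantiles = torch.randn(batch, n_nets, n_quantiles)
    samples = torch.randn(batch, n_samples)

    # pairwise_delta: batch x nets x quantiles x samples
    pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
    abs_delta = torch.abs(pairwise_delta)
    huber = torch.where(abs_delta > 1, abs_delta - 0.5, pairwise_delta ** 2 * 0.5)
    # Quantile midpoints: 1/(2n), 3/(2n), ..., (2n-1)/(2n).
    tau = torch.arange(n_quantiles).float() / n_quantiles + 1 / (2 * n_quantiles)
    loss = (torch.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber).mean()

    print(pairwise_delta.shape)  # torch.Size([4, 5, 25, 23])
    print(loss.item())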
class TQC(object):
    def __init__(self, state_dim, action_dim, actor_input_dim, args):
        input_dim = [args.history_length, args.size, args.size]
        self.actor = Actor(state_dim, action_dim, args).to(args.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), args.lr_actor)
        self.critic = CNNCritic(input_dim, state_dim, action_dim, args).to(args.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), args.lr_critic)
        self.target_critic = CNNCritic(input_dim, state_dim, action_dim, args).to(args.device)
        # Initialize the target from the online critic (the original copied the
        # target onto itself, which is a no-op).
        self.target_critic.load_state_dict(self.critic.state_dict())
        # The actor target used in the soft update below was never created in
        # the original; it is assumed to start as a copy of the actor.
        self.actor_target = copy.deepcopy(self.actor)
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.tau = args.tau
        self.device = args.device
        self.write_tensorboard = False
        self.top_quantiles_to_drop = args.top_quantiles_to_drop
        self.target_entropy = args.target_entropy
        self.quantiles_total = self.critic.n_quantiles * self.critic.n_nets
        self.log_alpha = torch.zeros((1,), requires_grad=True, device=args.device)
        # policy_freq and actor_clip_gradient are used in train() but were not
        # set in the original __init__; they are assumed to come from args.
        self.policy_freq = args.policy_freq
        self.actor_clip_gradient = args.actor_clip_gradient
        self.total_it = 0
        self.step = 0

    def train(self, replay_buffer, writer, iterations):
        self.step += 1
        if self.step % 1000 == 0:
            self.write_tensorboard = 1 - self.write_tensorboard
        for it in range(iterations):
            # Sample a batch of transitions (s, a, r, s', not_done) plus augmented observations.
            obs, action, reward, next_obs, not_done, obs_aug, next_obs_aug = replay_buffer.sample(self.batch_size)
            obs = obs.div_(255)
            next_obs = next_obs.div_(255)
            obs_aug = obs_aug.div_(255)
            next_obs_aug = next_obs_aug.div_(255)

            state = self.critic.create_vector(obs)
            detach_state = state.detach()
            state_aug = self.critic.create_vector(obs_aug)
            next_state = self.target_critic.create_vector(next_obs)
            detach_state_aug = state_aug.detach()
            next_state_aug = self.target_critic.create_vector(next_obs_aug)
            alpha = torch.exp(self.log_alpha)

            with torch.no_grad():
                # Policy action at the next state.
                new_next_action, next_log_pi = self.actor(next_state)
                # Quantiles at the next state: pool, sort and truncate.
                next_z = self.target_critic(next_state, new_next_action)
                sorted_z, _ = torch.sort(next_z.reshape(self.batch_size, -1))
                sorted_z_part = sorted_z[:, :self.quantiles_total - self.top_quantiles_to_drop]
                # The original never defines target_Q; the mean of the truncated
                # quantiles is assumed here as the scalar TD target.
                target_Q = reward + not_done * self.discount * sorted_z_part.mean(dim=1, keepdim=True)

            # Twin-critic MSE loss against the truncated target.
            current_Q1, current_Q2 = self.critic(state, action)
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
            if self.write_tensorboard:
                writer.add_scalar('Critic loss', critic_loss, self.step)
            # Again for the augmented observation.
            Q1_aug, Q2_aug = self.critic(state_aug, action)
            critic_loss += F.mse_loss(Q1_aug, target_Q) + F.mse_loss(Q2_aug, target_Q)

            # Backpropagate the critic loss and update both critics.
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Delayed policy update: every policy_freq iterations, do gradient
            # ascent on the first critic's value of the actor's action.
            if it % self.policy_freq == 0:
                obs = replay_buffer.sample_actor(self.batch_size)
                obs = obs.div_(255)
                state = self.critic.create_vector(obs)
                actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
                if self.write_tensorboard:
                    writer.add_scalar('Actor loss', actor_loss, self.step)
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                torch.nn.utils.clip_grad_value_(self.actor.parameters(), self.actor_clip_gradient)
                self.actor_optimizer.step()

                # Polyak-average actor and critic into their targets.
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
                for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def hardupdate(self):
        # Cycle through a list of target critics and hard-copy the online critic
        # into the current one. update_counter, num_q_target and
        # list_target_critic are expected to be set up elsewhere; they are not
        # initialized in this class.
        self.update_counter += 1
        self.currentQNet = self.update_counter % self.num_q_target
        for param, target_param in zip(self.target_critic.parameters(),
                                       self.list_target_critic[self.currentQNet].parameters()):
            target_param.data.copy_(param.data)

    def quantile_huber_loss_f(self, quantiles, samples):
        # pairwise_delta: batch x nets x quantiles x samples
        pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
        abs_pairwise_delta = torch.abs(pairwise_delta)
        huber_loss = torch.where(abs_pairwise_delta > 1,
                                 abs_pairwise_delta - 0.5,
                                 pairwise_delta ** 2 * 0.5)
        n_quantiles = quantiles.shape[2]
        tau = torch.arange(n_quantiles, device=self.device).float() / n_quantiles + 1 / 2 / n_quantiles
        loss = (torch.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
        return loss

    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        # The target attribute in this class is target_critic, not critic_target.
        self.target_critic = copy.deepcopy(self.critic)
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)
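# Every agent in this section repeats the same parameter-wise loop for the soft
# (Polyak) target update. A minimal, self-contained sketch of that update as a
# standalone helper, shown on two tiny linear layers purely for illustration
# (soft_update is a hypothetical helper, not part of the code above):
def soft_update(net, target_net, tau):
    """target <- tau * net + (1 - tau) * target, parameter by parameter."""
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)


if __name__ == "__main__":
    import copy

    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)
    target = copy.deepcopy(net)
    # Pretend the online net has moved away from the target.
    with torch.no_grad():
        net.weight.add_(1.0)
    soft_update(net, target, tau=0.005)
    # The target has moved only 0.5% of the way toward the online net.
    print((net.weight - target.weight).abs().mean().item())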
class TQC(object):
    def __init__(self, state_dim, action_dim, actor_input_dim, args):
        input_dim = [args.history_length, args.size, args.size]
        self.actor = Actor(state_dim, action_dim, args).to(args.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), args.lr_actor)
        self.critic = Critic(state_dim, action_dim, args).to(args.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), args.lr_critic)
        self.target_critic = Critic(state_dim, action_dim, args).to(args.device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.decoder = Decoder(args).to(args.device)
        self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), args.lr_decoder)
        self.target_decoder = Decoder(args).to(args.device)
        self.target_decoder.load_state_dict(self.decoder.state_dict())
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.tau = args.tau
        self.device = args.device
        self.write_tensorboard = False
        self.top_quantiles_to_drop = args.top_quantiles_to_drop_per_net * args.n_nets
        self.target_entropy = args.target_entropy
        self.quantiles_total = self.critic.n_quantiles * self.critic.n_nets
        self.log_alpha = torch.zeros((1,), requires_grad=True, device=args.device)
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=args.lr_alpha)
        self.total_it = 0
        self.step = 0

    def train(self, replay_buffer, writer, iterations):
        self.step += 1
        if self.step % 1000 == 0:
            self.write_tensorboard = 1 - self.write_tensorboard
        for it in range(iterations):
            # Sample a batch of transitions (s, a, r, s', not_done) plus augmented
            # observations; stdout is silenced to suppress the buffer's prints.
            sys.stdout = open(os.devnull, "w")
            obs, action, reward, next_obs, not_done, obs_aug, obs_next_aug = replay_buffer.sample(self.batch_size)
            sys.stdout = sys.__stdout__

            # First augmentation.
            obs = obs.div_(255)
            next_obs = next_obs.div_(255)
            state = self.decoder.create_vector(obs)
            detach_state = state.detach()
            next_state = self.target_decoder.create_vector(next_obs)

            # Second augmentation.
            obs_aug = obs_aug.div_(255)
            next_obs_aug = obs_next_aug.div_(255)
            state_aug = self.decoder.create_vector(obs_aug)
            detach_state_aug = state_aug.detach()
            next_state_aug = self.target_decoder.create_vector(next_obs_aug)

            alpha = torch.exp(self.log_alpha)
            with torch.no_grad():
                # Policy action at the next state.
                new_next_action, next_log_pi = self.actor(next_state)
                # Quantiles at the next state: pool, sort and truncate.
                next_z = self.target_critic(next_state, new_next_action)
                sorted_z, _ = torch.sort(next_z.reshape(self.batch_size, -1))
                sorted_z_part = sorted_z[:, :self.quantiles_total - self.top_quantiles_to_drop]
                target = reward + not_done * self.discount * (sorted_z_part - alpha * next_log_pi)

                # Again for the augmented observation.
                new_next_action_aug, next_log_pi_aug = self.actor(next_state_aug)
                next_z_aug = self.target_critic(next_state_aug, new_next_action_aug)
                sorted_z_aug, _ = torch.sort(next_z_aug.reshape(self.batch_size, -1))
                sorted_z_part_aug = sorted_z_aug[:, :self.quantiles_total - self.top_quantiles_to_drop]
                target_aug = reward + not_done * self.discount * (sorted_z_part_aug - alpha * next_log_pi_aug)
                target = (target + target_aug) / 2.

            # --- update critic (and the decoder that feeds it)
            cur_z = self.critic(state, action)
            critic_loss = quantile_huber_loss_f(cur_z, target, self.device)
            # Augmented observation against the same averaged target.
            cur_z_aug = self.critic(state_aug, action)
            critic_loss += quantile_huber_loss_f(cur_z_aug, target, self.device)
            self.critic_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()
            critic_loss.backward()
            self.decoder_optimizer.step()
            self.critic_optimizer.step()

            # Polyak-average critic and decoder into their targets.
            for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for param, target_param in zip(self.decoder.parameters(), self.target_decoder.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            # --- update policy and temperature alpha
            new_action, log_pi = self.actor(detach_state)
            alpha_loss = -self.log_alpha * (log_pi + self.target_entropy).detach().mean()
            actor_loss = (alpha * log_pi - self.critic(detach_state, new_action).mean(2).mean(1, keepdim=True)).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.total_it += 1

    def select_action(self, obs):
        obs = torch.FloatTensor(obs).to(self.device)
        obs = obs.div_(255)
        state = self.decoder.create_vector(obs.unsqueeze(0))
        return self.actor.select_action(state)

    def quantile_huber_loss_f(self, quantiles, samples):
        # pairwise_delta: batch x nets x quantiles x samples
        pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]
        abs_pairwise_delta = torch.abs(pairwise_delta)
        huber_loss = torch.where(abs_pairwise_delta > 1,
                                 abs_pairwise_delta - 0.5,
                                 pairwise_delta ** 2 * 0.5)
        n_quantiles = quantiles.shape[2]
        tau = torch.arange(n_quantiles, device=self.device).float() / n_quantiles + 1 / 2 / n_quantiles
        loss = (torch.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean()
        return loss

    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        # The target attribute in this class is target_critic, not critic_target.
        self.target_critic = copy.deepcopy(self.critic)
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)
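# The temperature update in train() adjusts alpha so the policy entropy tracks
# target_entropy: when the policy is too deterministic, alpha rises and the
# entropy bonus in the actor loss gets stronger. A small, self-contained numeric
# sketch of that mechanism (all values below are illustrative assumptions):
if __name__ == "__main__":
    import torch

    target_entropy = -6.0                       # e.g. -action_dim
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optimizer = torch.optim.Adam([log_alpha], lr=1e-2)

    # Pretend the policy is too deterministic: entropy (-log_pi = -7) is below
    # the target (-6), so log_pi + target_entropy is positive.
    log_pi = torch.full((256, 1), 7.0)

    for _ in range(100):
        alpha_loss = -log_alpha * (log_pi + target_entropy).detach().mean()
        alpha_optimizer.zero_grad()
        alpha_loss.backward()
        alpha_optimizer.step()

    # alpha has grown above its initial value of 1, strengthening the entropy bonus.
    print(torch.exp(log_alpha).item())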