def output_mse(result_fx, result_fl):
    import torch.nn.functional as F
    _, train_fx = result_fx
    _, train_fl = result_fl
    # mse_loss is already non-negative, so no abs() is needed
    loss = [F.mse_loss(x, y) for x, y in zip(train_fx, train_fl)]
    return sum(loss)
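A minimal usage sketch for output_mse above, with hypothetical inputs: each result is assumed to be a pair whose second element is a list of tensors.

# hedged usage sketch (hypothetical values; assumes output_mse above is in scope)
import torch
fx = (None, [torch.randn(4, 3), torch.randn(4, 3)])
fl = (None, [torch.randn(4, 3), torch.randn(4, 3)])
print(output_mse(fx, fl))  # sum of the per-pair MSE values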
def train(self, observations, actions):
    actions = torch.from_numpy(actions).float().to(self._device)
    observations = torch.from_numpy(observations).float().to(self._device)
    _actionsPred = self._backbone(observations)
    # compute loss (gaussian policy -> mse-loss)
    self._optimizer.zero_grad()
    _loss = F.mse_loss(_actionsPred, actions)
    _loss.backward()
    # update the weights of our backbone
    self._optimizer.step()
    # grab loss value
    self._currentTrainLoss = _loss.item()
    # log training information ---------------------------------------------
    if not self._logger:
        # sanity-check
        assert self._logpath is not None, 'ERROR> logpath should be defined'
        # create logger once
        self._logger = SummaryWriter(self._logpath)
    self._logger.add_scalar('log_1_loss', self._currentTrainLoss, self._istep)
    # ----------------------------------------------------------------------
    # book keeping
    self._istep += 1
def evaluate(self, sample, model_out):
    # Calculate downstream FC + MLP loss/accuracies
    results = super().evaluate(sample, model_out, multi_out=True)
    inp_img = sample[0]
    out_img, mu, logvar = [tmp_out.float() for tmp_out in model_out[1]]
    out_res = out_img.shape[2]
    # Reconstruction loss
    target_img = F.interpolate(inp_img.float(), size=out_res)
    MSE = F.mse_loss(target_img, out_img)
    # KL divergence
    if self.lmd > 0:
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    else:
        KLD = torch.Tensor([0])[0]
    vae_loss = MSE + self.lmd * KLD
    loss = vae_loss + results['loss_downstream']
    results['loss_vae'] = vae_loss
    results['loss_mse'] = MSE
    results['loss_kld'] = KLD
    return loss, results
def test_fake_task():
    X = torch.rand((5000, 100)) * 100
    A = torch.rand((100, 10))
    y = X @ A
    X = X.cuda()
    y = y.cuda()
    mlp = MLP(100, [], 10).cuda()
    opt = torch.optim.Adam(mlp.parameters())
    for _ in range(10000):
        loss = F.mse_loss(y, mlp.forward(X))
        opt.zero_grad()
        loss.backward()
        opt.step()
    assert F.mse_loss(mlp.state_dict()["mlp.0.weight"].T.cpu(), A).item() < 1e-2
def train(self, obs: dict) -> None:
    obs, z = self._split_obs(obs)
    mu, var = self.model.forward(obs).split(self.num_skills, 1)
    pred_z = pyd.Normal(mu, torch.exp(var)).rsample()
    loss = F.mse_loss(pred_z, z)
    self.log("loss", loss)
    self.optim.zero_grad()
    loss.backward()
    self.optim.step()
def get_style_loss(self, y_):
    n = y_.size(0)
    y_ = self.get_style_features(y_)
    y = self.style_features
    l_style = 0
    for g_, g in zip(y_, y):
        l_style += F.mse_loss(g_, g.expand_as(g_), reduction='sum')
    return l_style / n
def train(self, states, actions, qtargets):
    with autograd.detect_anomaly():
        # transform to torch tensors
        states = torch.from_numpy(states).float().to(self._device)
        actions = torch.from_numpy(actions).float().to(self._device)
        qtargets = torch.from_numpy(qtargets).float().to(self._device)
        # compute q-values for Q(s,a), where s,a come from the given
        # states and actions batches passed along the q-targets
        _qvalues = self._backbone([states, actions])
        # compute loss for the critic
        self._optimizer.zero_grad()
        _lossCritic = F.mse_loss(_qvalues, qtargets)
        _lossCritic.backward()
        if self._backbone.config.clipGradients:
            nn.utils.clip_grad_norm_(self._backbone.parameters(),
                                     self._backbone.config.gradientsClipNorm)
        # take a step with the optimizer
        self._optimizer.step()
def calculate_model_losses(args, skip_pixel_loss, img, img_pred, bbox, bbox_pred):
    total_loss = torch.zeros(1).to(img)
    losses = {}
    l1_pixel_weight = args.l1_pixel_loss_weight
    if skip_pixel_loss:
        l1_pixel_weight = 0
    l1_pixel_loss = F.l1_loss(img_pred, img)
    total_loss = add_loss(total_loss, l1_pixel_loss, losses, 'L1_pixel_loss',
                          l1_pixel_weight)
    loss_bbox = F.mse_loss(bbox_pred, bbox)
    total_loss = add_loss(total_loss, loss_bbox, losses, 'bbox_pred',
                          args.bbox_pred_loss_weight)
    return total_loss, losses
def train( env, num_episodes = 2000 ) :
    _actorNetLocal = PiNetwork( env.observation_space.shape, env.action_space.shape ).to( DEVICE )
    _actorNetTarget = PiNetwork( env.observation_space.shape, env.action_space.shape ).to( DEVICE )
    _actorNetTarget.copy( _actorNetLocal )

    _criticNetLocal = Qnetwork( env.observation_space.shape, env.action_space.shape ).to( DEVICE )
    _criticNetTarget = Qnetwork( env.observation_space.shape, env.action_space.shape ).to( DEVICE )
    _criticNetTarget.copy( _criticNetLocal )

    _rbuffer = ReplayBuffer( REPLAY_BUFFER_SIZE )
    ## _noise = OUNoise( env.action_space.shape )
    _noise = OUNoise2( env.action_space.shape )

    _optimActor = opt.Adam( _actorNetLocal.parameters(), lr = LEARNING_RATE_ACTOR )
    _optimCritic = opt.Adam( _criticNetLocal.parameters(), lr = LEARNING_RATE_CRITIC,
                             weight_decay = WEIGHT_DECAY )

    progressbar = tqdm( range( 1, num_episodes + 1 ), desc = 'Training>' )
    scoresAvgs = []
    scoresWindow = deque( maxlen = LOG_WINDOW )
    bestScore = -np.inf
    avgScore = -np.inf
    writer = SummaryWriter( 'summary_bipedal_bn' )
    istep = 0
    epsilon = 1.0

    for iepisode in progressbar :
        _s = env.reset()
        _noise.reset()
        _score = 0.

        for _ in range( MAX_STEPS_IN_EPISODE ) :
            if istep < TRAINING_STARTING_STEP :
                _a = np.clip( np.random.randn( *env.action_space.shape ), -1., 1. )
            else :
                # choose an action using the actor network and a noise process
                _actorNetLocal.eval()
                with torch.no_grad() :
                    _a = _actorNetLocal( torch.from_numpy( _s ).unsqueeze( 0 ).float().to( DEVICE ) ).cpu().data.numpy().squeeze()
                _actorNetLocal.train()
                # add noise and clip accordingly
                _a += epsilon * _noise.sample()
                _a = np.clip( _a, -1., 1. )

            # take action in the environment and collect the reward
            _snext, _r, _done, _ = env.step( _a )
            _rbuffer.store( ( _s, _a, _r, _snext, _done ) )

            if len( _rbuffer ) > BATCH_SIZE and istep % TRAIN_FREQUENCY_STEPS == 0 and \
               istep >= TRAINING_STARTING_STEP :
                # grab a batch of data from the replay buffer
                _states, _actions, _rewards, _statesNext, _dones = _rbuffer.sample( BATCH_SIZE )

                # compute current q-values for the 'actions' taken at 'states' using critic
                _qvalues = _criticNetLocal( _states, _actions )
                # compute target q-values using both target actor and critic
                with torch.no_grad() :
                    _actionsNext = _actorNetTarget( _statesNext )
                    _qvaluesTarget = _rewards + ( 1. - _dones ) * GAMMA * _criticNetTarget( _statesNext, _actionsNext )

                # compute loss for the critic
                _optimCritic.zero_grad()
                _lossCritic = F.mse_loss( _qvalues, _qvaluesTarget )
                _lossCritic.backward()
                _optimCritic.step()

                # compute loss for the actor, from the objective to "maximize":
                #
                #   dJ / dtheta = E[ dQ / du * du / dtheta ]
                #
                # where:
                #   * theta       : weights of the actor
                #   * dQ / du     : gradient of Q w.r.t. u (the actions taken)
                #   * du / dtheta : gradient of the actor's output w.r.t. its weights
                _optimActor.zero_grad()
                # compute actions taken in these states by the current state of the actor
                _actionsPred = _actorNetLocal( _states )
                # compose the critic over the actor outputs (sandwich), which effectively does g(f(x))
                _lossActor = -_criticNetLocal( _states, _actionsPred ).mean()
                _lossActor.backward()
                _optimActor.step()

                # update target networks
                _actorNetTarget.copy( _actorNetLocal, TAU )
                _criticNetTarget.copy( _criticNetLocal, TAU )

            # book keeping for next iteration
            _s = _snext
            _score += _r
            istep += 1

            if _done :
                break

        # update epsilon using schedule
        if istep >= TRAINING_STARTING_STEP :
            epsilon = max( 0.1, epsilon * EPSILON_DECAY_FACTOR )

        # update some info for logging
        bestScore = max( bestScore, _score )
        scoresWindow.append( _score )

        if iepisode >= LOG_WINDOW :
            avgScore = np.mean( scoresWindow )
            scoresAvgs.append( avgScore )
            message = 'Training> best: %.2f - mean: %.2f - current: %.2f'
            progressbar.set_description( message % ( bestScore, avgScore, _score ) )
            progressbar.refresh()
        else :
            message = 'Training> best: %.2f - current : %.2f'
            progressbar.set_description( message % ( bestScore, _score ) )
            progressbar.refresh()

        writer.add_scalar( 'score', _score, iepisode )
        writer.add_scalar( 'avg_score', np.mean( scoresWindow ), iepisode )
        writer.add_scalar( 'buffer_size', len( _rbuffer ), iepisode )
        writer.add_scalar( 'epsilon', epsilon, iepisode )

    torch.save( _actorNetLocal.state_dict(), './saved/pytorch/ddpg_actor_bipedal.pth' )
    torch.save( _criticNetLocal.state_dict(), './saved/pytorch/ddpg_critic_bipedal.pth' )
def get_content_loss(self, y_, y):
    y_ = self.backbone_content(y_)
    y = self.backbone_content(y)
    return F.mse_loss(y_, y)
def train( env, seed, num_episodes ) :
    ##------------- Create actor network (+its target counterpart)------------##
    actorsNetsLocal = [ PiNetwork( env.observation_space.shape,
                                   env.action_space.shape,
                                   seed ) for _ in range( NUM_AGENTS ) ]
    actorsNetsTarget = [ PiNetwork( env.observation_space.shape,
                                    env.action_space.shape,
                                    seed ) for _ in range( NUM_AGENTS ) ]
    for _netLocal, _netTarget in zip( actorsNetsLocal, actorsNetsTarget ) :
        _netTarget.copy( _netLocal )
        _netLocal.to( DEVICE )
        _netTarget.to( DEVICE )

    optimsActors = [ opt.Adam( _actorNet.parameters(), lr = LEARNING_RATE_ACTOR ) \
                     for _actorNet in actorsNetsLocal ]

    # print a brief summary of the network
    summary( actorsNetsLocal[0], env.observation_space.shape )
    print( actorsNetsLocal[0] )

    ##----------- Create critic network (+its target counterpart)-------------##
    criticsNetsLocal = [ Qnetwork( (NUM_AGENTS * env.observation_space.shape[0],),
                                   (NUM_AGENTS * env.action_space.shape[0],),
                                   seed ) for _ in range( NUM_AGENTS ) ]
    criticsNetsTarget = [ Qnetwork( (NUM_AGENTS * env.observation_space.shape[0],),
                                    (NUM_AGENTS * env.action_space.shape[0],),
                                    seed ) for _ in range( NUM_AGENTS ) ]
    for _netLocal, _netTarget in zip( criticsNetsLocal, criticsNetsTarget ) :
        _netTarget.copy( _netLocal )
        _netLocal.to( DEVICE )
        _netTarget.to( DEVICE )

    optimsCritics = [ opt.Adam( _criticNet.parameters(), lr = LEARNING_RATE_CRITIC ) \
                      for _criticNet in criticsNetsLocal ]

    # print a brief summary of the network
    summary( criticsNetsLocal[0], [(NUM_AGENTS * env.observation_space.shape[0],),
                                   (NUM_AGENTS * env.action_space.shape[0],)] )
    print( criticsNetsLocal[0] )
    ##------------------------------------------------------------------------##

    # Circular Replay buffer
    rbuffer = ReplayBuffer( REPLAY_BUFFER_SIZE, NUM_AGENTS )
    # Noise process
    noise = OUNoise( env.action_space.shape, seed )
    # Noise scale factor (annealed with a schedule)
    epsilon = 1.0

    progressbar = tqdm( range( 1, num_episodes + 1 ), desc = 'Training>' )
    scoresAvgs = []
    scoresWindow = deque( maxlen = LOG_WINDOW )
    bestScore = -np.inf
    avgScore = -np.inf

    from tensorboardX import SummaryWriter
    writer = SummaryWriter( os.path.join( SESSION_FOLDER, 'tensorboard_summary' ) )
    istep = 0

    for iepisode in progressbar :
        noise.reset()
        _oo = env.reset()
        _scoreAgents = np.zeros( NUM_AGENTS )

        for i in range( MAX_STEPS_IN_EPISODE ) :
            # take full-random actions during these many steps
            if istep < TRAINING_STARTING_STEP :
                _aa = np.clip( np.random.randn( *((NUM_AGENTS,) + env.action_space.shape) ), -1., 1. )
            # take actions from exploratory policy
            else :
                # eval-mode (in case batchnorm is used)
                for _actorNet in actorsNetsLocal :
                    _actorNet.eval()

                # choose an action for each agent using its own actor network
                with torch.no_grad() :
                    _aa = []
                    for iactor, _actorNet in enumerate( actorsNetsLocal ) :
                        # evaluate action to take from each actor policy
                        _a = _actorNet( torch.from_numpy( _oo[iactor] ).unsqueeze( 0 ).float().to( DEVICE ) ).cpu().data.numpy().squeeze()
                        _aa.append( _a )
                    _aa = np.array( _aa )

                # add some noise sampled from the noise process (each agent gets a different sample)
                _nn = np.array( [ epsilon * noise.sample() for _ in range( NUM_AGENTS ) ] ).reshape( _aa.shape )
                _aa += _nn
                # actions are speed-factors (range (-1,1)) in both x and y
                _aa = np.clip( _aa, -1., 1. )

                # back to train-mode (in case batchnorm is used)
                for _actorNet in actorsNetsLocal :
                    _actorNet.train()

            # take action in the environment and collect the reward
            _oonext, _rr, _dd, _ = env.step( _aa )

            # store joint information (of shape (NUM_AGENTS,) + MEASUREMENT-SHAPE)
            if i == MAX_STEPS_IN_EPISODE - 1 :
                rbuffer.store( ( _oo, _aa, _rr, _oonext, np.ones_like( _dd ) ) )
            else :
                rbuffer.store( ( _oo, _aa, _rr, _oonext, _dd ) )

            if len( rbuffer ) > BATCH_SIZE and istep % TRAIN_FREQUENCY_STEPS == 0 and \
               istep >= TRAINING_STARTING_STEP :
                for _ in range( TRAIN_NUM_UPDATES ) :
                    # grab a batch of data from the replay buffer
                    _observations, _actions, _rewards, _observationsNext, _dones = rbuffer.sample( BATCH_SIZE )

                    # compute joint observations and actions to be passed to the critic,
                    # which basically consists of keeping the batch dimension and
                    # vectorizing everything else into one single dimension
                    # [o1,...,on] and [a1,...,an]
                    _batchJointObservations = _observations.reshape( _observations.shape[0], -1 )
                    _batchJointObservationsNext = _observationsNext.reshape( _observationsNext.shape[0], -1 )
                    _batchJointActions = _actions.reshape( _actions.shape[0], -1 )

                    # compute the joint next actions required for the centralized
                    # critics' q-target computation
                    with torch.no_grad() :
                        _batchJointActionsNext = torch.stack( [ actorsNetsTarget[iactor]( _observationsNext[:,iactor,:] ) \
                                                                for iactor in range( NUM_AGENTS ) ], dim = 1 )
                        _batchJointActionsNext = _batchJointActionsNext.reshape( _batchJointActionsNext.shape[0], -1 )

                    for iactor in range( NUM_AGENTS ) :
                        # extract local observations to be fed to the actors, as well as
                        # local rewards and dones to be used for local q-targets
                        # computation using the critics
                        _batchLocalObservations = _observations[:,iactor,:]
                        _batchLocalRewards = _rewards[:,iactor,:]
                        _batchLocalDones = _dones[:,iactor,:]

                        #---------------------- TRAIN CRITICS --------------------#
                        # compute current q-values for the joint-actions taken at
                        # joint-observations using the critic, as explained in the
                        # MADDPG algorithm:
                        #
                        #   Q_phi-i(x, a1, a2, ..., an) -> Q_phi-i( [o1,o2,...,on], [a1,a2,...,an] )
                        _qvalues = criticsNetsLocal[iactor]( _batchJointObservations, _batchJointActions )

                        # compute target q-values using both the decentralized target
                        # actor and the centralized target critic for this current
                        # actor, as explained in the MADDPG algorithm:
                        #
                        #   Q-targets_i = r_i + ( 1 - done_i ) * gamma * Q_phi-target-i( [o1',...,on'], [a1',...,an'] )
                        with torch.no_grad() :
                            _qvaluesTarget = _batchLocalRewards + ( 1. - _batchLocalDones ) \
                                             * GAMMA * criticsNetsTarget[iactor]( _batchJointObservationsNext,
                                                                                  _batchJointActionsNext )

                        # compute loss for the critic
                        optimsCritics[iactor].zero_grad()
                        _lossCritic = F.mse_loss( _qvalues, _qvaluesTarget )
                        _lossCritic.backward()
                        torch.nn.utils.clip_grad_norm_( criticsNetsLocal[iactor].parameters(), 1 )
                        optimsCritics[iactor].step()

                        #---------------------- TRAIN ACTORS ---------------------#
                        # compute loss for the actor, from the objective to "maximize":
                        #
                        #   dJ / dtheta = E[ dQ / du * du / dtheta ]
                        #
                        # where:
                        #   * theta       : weights of the actor
                        #   * dQ / du     : gradient of Q w.r.t. u (the actions taken)
                        #   * du / dtheta : gradient of the actor's output w.r.t. its weights
                        optimsActors[iactor].zero_grad()

                        # compute predicted actions for the current local observations,
                        # as we will need them for computing the gradients of the actor.
                        # Recall that these gradients depend on the gradients of its own
                        # related centralized critic, which needs the joint actions to
                        # work. Keep gradients enabled here, as we have to build the
                        # computation graph with these operations
                        _batchJointActionsPred = torch.stack( [ actorsNetsLocal[indexActor]( _observations[:,indexActor,:] ) \
                                                                for indexActor in range( NUM_AGENTS ) ], dim = 1 )
                        _batchJointActionsPred = _batchJointActionsPred.reshape( _batchJointActionsPred.shape[0], -1 )

                        # compose the critic over the actor outputs (sandwich), which effectively does g(f(x))
                        _lossActor = -criticsNetsLocal[iactor]( _batchJointObservations, _batchJointActionsPred ).mean()
                        _lossActor.backward()
                        optimsActors[iactor].step()

                        # update target networks
                        actorsNetsTarget[iactor].copy( actorsNetsLocal[iactor], TAU )
                        criticsNetsTarget[iactor].copy( criticsNetsLocal[iactor], TAU )

                # update epsilon using schedule
                if EPSILON_SCHEDULE == 'linear' :
                    epsilon = max( 0.1, epsilon - EPSILON_DECAY_LINEAR )
                else :
                    epsilon = max( 0.1, epsilon * EPSILON_DECAY_FACTOR )

                for iactor in range( NUM_AGENTS ) :
                    torch.save( actorsNetsLocal[iactor].state_dict(),
                                os.path.join( SESSION_FOLDER, 'maddpg_actor_reacher_' + str(iactor) + '.pth' ) )
                    torch.save( criticsNetsLocal[iactor].state_dict(),
                                os.path.join( SESSION_FOLDER, 'maddpg_critic_reacher_' + str(iactor) + '.pth' ) )

            # book keeping for next iteration
            _oo = _oonext
            _scoreAgents += _rr
            istep += 1

            if _dd.any() :
                break

        # update some info for logging
        _score = np.max( _scoreAgents )        # score of the game is the max over both agents' scores
        bestScore = max( bestScore, _score )   # max game score so far
        scoresWindow.append( _score )

        if iepisode >= LOG_WINDOW :
            avgScore = np.mean( scoresWindow )
            scoresAvgs.append( avgScore )
            message = 'Training> best: %.2f - mean: %.2f - current: %.2f'
            progressbar.set_description( message % ( bestScore, avgScore, _score ) )
            progressbar.refresh()
        else :
            message = 'Training> best: %.2f - current : %.2f'
            progressbar.set_description( message % ( bestScore, _score ) )
            progressbar.refresh()

        writer.add_scalar( 'score', _score, iepisode )
        writer.add_scalar( 'avg_score', np.mean( scoresWindow ), iepisode )
        writer.add_scalar( 'buffer_size', len( rbuffer ), iepisode )
        writer.add_scalar( 'epsilon', epsilon, iepisode )

    for iactor in range( NUM_AGENTS ) :
        torch.save( actorsNetsLocal[iactor].state_dict(),
                    os.path.join( SESSION_FOLDER, 'maddpg_actor_reacher_' + str(iactor) + '.pth' ) )
        torch.save( criticsNetsLocal[iactor].state_dict(),
                    os.path.join( SESSION_FOLDER, 'maddpg_critic_reacher_' + str(iactor) + '.pth' ) )
def eval_metrics(pred, target):
    rmse = torch.sqrt(F.mse_loss(pred, target))
    return {'rmse': rmse}
def pretraining(train_path, test_path, lr_index, g_noise_var, pre_num_iter,
                fine_num_iter, use_filter_data, filter_by_year):
    stock_name = train_path.split('/')[1].split('_')[0]
    p_1 = train_path.split('/')[1].split('_')[3].replace('p1', '')
    p_2 = train_path.split('/')[1].split('_')[4].replace('p2', '')

    # Load data (drop the timestamp, keep the first 11 features)
    train_data = np.load(train_path)[:, :11]
    test_data = np.load(test_path)[:, :11]
    train_data = torch.tensor(train_data).float()
    test_data = torch.tensor(test_data).float()

    # Initialize model and optimizer
    sdae_model = SDAE(11)
    model_optimizer = optim.Adam(sdae_model.parameters(), lr=10**(-1.0 * lr_index))

    # Layer-wise pretraining
    for layer_index in range(sdae_model.num_layers):
        sdae_model.freeze_all_but(layer_index)
        for iter in range(pre_num_iter):
            sample_indices = np.random.randint(train_data.shape[0], size=(BATCH_SIZE,))
            batch_input = train_data[sample_indices]
            batch_output = batch_input.clone()
            # Add gaussian noise
            corrupted_batch_input = batch_input + (g_noise_var**0.5) * torch.randn(batch_input.size())
            # Learning: feed the corrupted input so the autoencoder learns to denoise it
            pred = sdae_model.forward(corrupted_batch_input)
            loss = F.mse_loss(pred, batch_output)
            model_optimizer.zero_grad()
            loss.backward()
            model_optimizer.step()
        sdae_model.unfreeze_all()

    # Fine-tuning
    sdae_model.unfreeze_all()
    for iter in range(fine_num_iter):
        sample_indices = np.random.randint(train_data.shape[0], size=(BATCH_SIZE,))
        batch_input = train_data[sample_indices]
        batch_output = batch_input.clone()
        # Add gaussian noise
        corrupted_batch_input = batch_input + (g_noise_var**0.5) * torch.randn(batch_input.size())
        # Learning: again reconstruct the clean batch from the corrupted input
        pred = sdae_model.forward(corrupted_batch_input)
        loss = F.mse_loss(pred, batch_output)
        model_optimizer.zero_grad()
        loss.backward()
        model_optimizer.step()

        # Testing
        if (iter + 1) % EVAL_FREQ == 0:
            corrupted_test_input = test_data + (g_noise_var**0.5) * torch.randn(test_data.size())
            with torch.no_grad():
                pred = sdae_model.forward(corrupted_test_input)
                loss = F.mse_loss(pred, test_data)
            print("fine-tune iter: " + str(iter) + " loss: " + str(loss))
            sys.stdout.flush()

    if use_filter_data:
        torch.save(
            sdae_model.state_dict(),
            MODELS_PATH + stock_name + "_p1" + str(p_1) + "_p2" + str(p_2) +
            "_sdae_model_lr" + str(lr_index) + "_g_noise_var" + str(g_noise_var) +
            "_pre" + str(pre_num_iter) + "fine" + str(fine_num_iter) +
            "_filtered_fyear" + str(filter_by_year) + ".pt")
    else:
        torch.save(
            sdae_model.state_dict(),
            MODELS_PATH + stock_name + "_p1" + str(p_1) + "_p2" + str(p_2) +
            "_sdae_model_lr" + str(lr_index) + "_g_noise_var" + str(g_noise_var) +
            "_pre" + str(pre_num_iter) + "fine" + str(fine_num_iter) + ".pt")
def rmse(y_pred, target) -> Tensor:
    return sqrt(F.mse_loss(y_pred, target))
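A small usage sketch for the rmse helper above; it assumes sqrt is torch.sqrt and F is torch.nn.functional, and uses hypothetical example values.

# hedged usage sketch (hypothetical values; assumes the rmse helper above is in scope)
import torch
pred = torch.tensor([2.0, 4.0, 6.0])
target = torch.tensor([1.0, 5.0, 7.0])
print(rmse(pred, target))  # tensor(1.) since each squared error is 1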
def get_tv_loss(y_):
    dx = F.mse_loss(y_[:, :, :, 1:], y_[:, :, :, :-1])
    dy = F.mse_loss(y_[:, :, 1:, :], y_[:, :, :-1, :])
    return dx + dy
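A brief usage sketch for get_tv_loss above: the function penalizes squared differences between horizontally and vertically adjacent pixels, so it expects an NCHW image batch (the shapes below are hypothetical).

# hedged usage sketch: total-variation-style smoothness loss on a random NCHW batch
import torch
img = torch.rand(2, 3, 64, 64)   # (batch, channels, height, width)
tv = get_tv_loss(img)            # mean squared difference between neighbouring pixels
print(tv.item())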
def train(self, train_buffer, val_buffer, callback_fn):
    for epoch in range(self.args['max_epoch']):
        for i in range(self.args['steps_per_epoch']):
            batch_data = train_buffer.sample(self.batch_size)
            batch_data.to_torch(device=self.device)

            obs = batch_data['obs']
            action = batch_data['act']
            next_obs = batch_data['obs_next']
            reward = batch_data['rew']
            done = batch_data['done'].float()

            # train vae
            dist, _action = self.vae(obs, action)
            kl_loss = kl_divergence(dist, Normal(0, 1)).sum(dim=-1).mean()
            recon_loss = ((action - _action)**2).sum(dim=-1).mean()
            vae_loss = kl_loss + recon_loss
            self.vae_optim.zero_grad()
            vae_loss.backward()
            self.vae_optim.step()

            # train critic
            with torch.no_grad():
                repeat_next_obs = torch.repeat_interleave(next_obs.unsqueeze(0), 10, 0)
                multiple_actions = self.jitter_target(repeat_next_obs,
                                                      self.vae.decode(repeat_next_obs))
                obs_action = torch.cat([repeat_next_obs, multiple_actions], dim=-1)
                target_q1 = self.target_q1(obs_action)
                target_q2 = self.target_q2(obs_action)
                target_q = self.lam * torch.min(target_q1, target_q2) + \
                           (1 - self.lam) * torch.max(target_q1, target_q2)
                target_q = torch.max(target_q, dim=0)[0]
                target_q = reward + self.gamma * (1 - done) * target_q

            obs_action = torch.cat([obs, action], dim=-1)
            q1 = self.q1(obs_action)
            q2 = self.q2(obs_action)
            critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q)
            self.critic_optim.zero_grad()
            critic_loss.backward()
            self.critic_optim.step()

            # train jitter
            action = self.vae.decode(obs)
            action = self.jitter(obs, action)
            obs_action = torch.cat([obs, action], dim=-1)
            jitter_loss = -self.q1(obs_action).mean()
            self.jitter_optim.zero_grad()
            jitter_loss.backward()
            self.jitter_optim.step()

            # soft target update
            self._sync_weight(self.jitter_target, self.jitter,
                              soft_target_tau=self.args['soft_target_tau'])
            self._sync_weight(self.target_q1, self.q1,
                              soft_target_tau=self.args['soft_target_tau'])
            self._sync_weight(self.target_q2, self.q2,
                              soft_target_tau=self.args['soft_target_tau'])

        res = callback_fn(self.get_policy())
        res['kl_loss'] = kl_loss.item()
        self.log_res(epoch, res)

    return self.get_policy()
def update(self, rollouts: ReplayBuffer, iter_idx, device="cpu", callback=None):
    def compute_pi_loss():
        pass

    def compute_val_loss(self):
        pass

    if len(rollouts) <= 0:
        return

    sps_dict = rollouts.sample(batch_size='all')
    nn = self.actor_critic
    optimizer = self.optimizer

    total_loss = 0
    total_rewards = 0

    for key in sps_dict:
        if key not in nn.activated_agents:
            continue
        if sps_dict[key]:
            states, units, actions, next_states, rewards, hxses, done_masks, \
                durations, ctfs, irews = sps_dict[key].to(device)

            if self.actor_critic.recurrent:
                value, probs, _ = nn.forward(actor_type=key,
                                             spatial_feature=states,
                                             unit_feature=units,
                                             hxs=hxses.unsqueeze(0))
            else:
                value, probs, _ = nn.forward(actor_type=key,
                                             spatial_feature=states,
                                             unit_feature=units)

            m = torch.distributions.Categorical(probs=probs)
            value_next = nn.critic_forward(next_states).detach()

            # probability of the actions actually taken
            pi_sa = probs.gather(1, actions)
            # bootstrapped targets, discounted by the transition duration
            targets = rewards + (self.gamma ** durations) * (value_next * done_masks)

            advantages = targets - value
            adv = advantages.detach()

            entropy_loss = -m.entropy().mean()
            policy_loss = -(torch.log(pi_sa) * adv).mean()
            value_loss = F.mse_loss(value, targets)

            all_loss = policy_loss + value_loss * self.value_loss_coef \
                       + self.entropy_coef * entropy_loss

            total_loss += all_loss
            total_rewards += rewards.mean()

    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), .5)
    optimizer.step()

    results = {
        "rewards": total_rewards.mean(),
        "all_losses": all_loss,
    }

    if iter_idx % self.log_interval == 0:
        if callback:
            callback(iter_idx, results)

    rollouts.refresh()