# NOTE: shared dependencies for the learn() variants below. The exact module
# paths for the project-local helpers (vtrace, losses) are assumptions; adjust
# them to match the surrounding package layout.
import threading

import numpy as np
import torch
import torch.nn as nn

import losses
import vtrace


def learn(actor_model,
          model,
          batch,
          initial_agent_state,
          optimizer,
          scheduler,
          flags,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        # This pre-allocation is immediately overwritten by the per-state
        # visit counts below.
        intrinsic_rewards = torch.ones(
            (flags.unroll_length, flags.batch_size),
            dtype=torch.float32).to(device=flags.device)
        intrinsic_rewards = batch['train_state_count'][1:].float().to(
            device=flags.device)

        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        learner_outputs, unused_state = model(batch, initial_agent_state)
        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1] for key, tensor in learner_outputs.items()
        }

        rewards = batch['reward']
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)

        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(
            learner_outputs['policy_logits'],
            batch['action'],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss

        episode_returns = batch['episode_return'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
        }

        scheduler.step()
        optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
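

# ----------------------------------------------------------------------------
# Variant: count-based intrinsic reward with extra logging. Identical to the
# learn() above, except that it also accumulates per-position visit counts
# (in the caller-provided position_count dict) and an action-usage histogram
# (via the caller-provided action_hist object) purely for later
# visualization; neither statistic enters the loss.
# ----------------------------------------------------------------------------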
def learn(actor_model,
          model,
          batch,
          initial_agent_state,
          optimizer,
          scheduler,
          flags,
          action_hist,
          position_count,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        intrinsic_rewards = torch.ones(
            (flags.unroll_length, flags.batch_size),
            dtype=torch.float32).to(device=flags.device)

        # Accumulate per-position visit counts (used later to draw heatmaps).
        position_coord, position_counts = torch.unique(
            batch["agent_position"].view(-1, 2), return_counts=True, dim=0)
        for i, coord in enumerate(position_coord):
            coord = tuple(coord[:].cpu().numpy())
            position_count[coord] = position_count.get(
                coord, 0) + position_counts[i].item()

        # Record the action histogram for visualization only; it does not
        # influence the loss.
        action_id, count_action = torch.unique(batch["action"].flatten(),
                                               return_counts=True)
        acted_id, action_acted = torch.unique(
            (batch["action"] + 1) * batch["action_acted"] - 1,
            return_counts=True)
        action_hist.add(action_id, count_action.cpu().float(),
                        acted_id[1:].cpu(), action_acted[1:].cpu().float())

        intrinsic_rewards = batch['train_state_count'][1:].float().to(
            device=flags.device)
        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        learner_outputs, unused_state = model(batch, initial_agent_state)
        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1] for key, tensor in learner_outputs.items()
        }

        rewards = batch['reward']
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)

        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(
            learner_outputs['policy_logits'],
            batch['action'],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss

        episode_returns = batch['episode_return'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
        }

        scheduler.step()
        optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
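

# ----------------------------------------------------------------------------
# Variant: curiosity-driven intrinsic reward (forward/inverse dynamics in the
# style of ICM). The reward is the prediction error of a learned forward
# dynamics model in embedding space; the embedding, forward-dynamics and
# inverse-dynamics networks are trained jointly with the policy.
# ----------------------------------------------------------------------------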
def learn(actor_model,
          model,
          state_embedding_model,
          forward_dynamics_model,
          inverse_dynamics_model,
          batch,
          initial_agent_state,
          optimizer,
          state_embedding_optimizer,
          forward_dynamics_optimizer,
          inverse_dynamics_optimizer,
          scheduler,
          flags,
          frames=None,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        state_emb = state_embedding_model(
            batch['frame'][:-1].to(device=flags.device))
        next_state_emb = state_embedding_model(
            batch['frame'][1:].to(device=flags.device))
        pred_next_state_emb = forward_dynamics_model(
            state_emb, batch['action'][1:].to(device=flags.device))
        pred_actions = inverse_dynamics_model(state_emb, next_state_emb)

        entropy_emb_actions = losses.compute_entropy_loss(pred_actions)

        intrinsic_rewards = torch.norm(pred_next_state_emb - next_state_emb,
                                       dim=2, p=2)
        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        forward_dynamics_loss = flags.forward_loss_coef * \
            losses.compute_forward_dynamics_loss(pred_next_state_emb,
                                                 next_state_emb)
        inverse_dynamics_loss = flags.inverse_loss_coef * \
            losses.compute_inverse_dynamics_loss(pred_actions,
                                                 batch['action'][1:])

        num_samples = flags.unroll_length * flags.batch_size
        actions_flat = batch['action'][1:].reshape(
            num_samples).cpu().detach().numpy()
        intrinsic_rewards_flat = intrinsic_rewards.reshape(
            num_samples).cpu().detach().numpy()

        learner_outputs, unused_state = model(batch, initial_agent_state)
        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1] for key, tensor in learner_outputs.items()
        }

        actions = batch['action'].reshape(
            flags.unroll_length * flags.batch_size).cpu().numpy()
        action_percentage = [0 for _ in range(model.num_actions)]
        for i in range(model.num_actions):
            action_percentage[i] = np.sum([a == i for a in actions]) / len(actions)

        rewards = batch['reward']
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)

        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(
            learner_outputs['policy_logits'],
            batch['action'],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss \
            + forward_dynamics_loss + inverse_dynamics_loss

        episode_returns = batch['episode_return'][batch['done']]
        episode_lengths = batch['episode_step'][batch['done']]
        episode_wins = batch['episode_win'][batch['done']]

        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'forward_dynamics_loss': forward_dynamics_loss.item(),
            'inverse_dynamics_loss': inverse_dynamics_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
        }

        scheduler.step()
        optimizer.zero_grad()
        state_embedding_optimizer.zero_grad()
        forward_dynamics_optimizer.zero_grad()
        inverse_dynamics_optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(state_embedding_model.parameters(),
                                 flags.max_grad_norm)
        nn.utils.clip_grad_norm_(forward_dynamics_model.parameters(),
                                 flags.max_grad_norm)
        nn.utils.clip_grad_norm_(inverse_dynamics_model.parameters(),
                                 flags.max_grad_norm)
        optimizer.step()
        state_embedding_optimizer.step()
        forward_dynamics_optimizer.step()
        inverse_dynamics_optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
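

# ----------------------------------------------------------------------------
# Variant: Random Network Distillation (RND) intrinsic reward with extra
# logging. The reward is the distance between a fixed random target embedding
# and a trained predictor embedding of the next observation; per-position
# visit counts and an action histogram are recorded for visualization only.
# ----------------------------------------------------------------------------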
def learn(actor_model,
          model,
          random_target_network,
          predictor_network,
          batch,
          initial_agent_state,
          optimizer,
          predictor_optimizer,
          scheduler,
          flags,
          action_hist,
          position_count,
          frames=None,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        if flags.use_fullobs_intrinsic:
            random_embedding = random_target_network(batch, next_state=True) \
                .reshape(flags.unroll_length, flags.batch_size, 128)
            predicted_embedding = predictor_network(batch, next_state=True) \
                .reshape(flags.unroll_length, flags.batch_size, 128)
        else:
            random_embedding = random_target_network(
                batch['partial_obs'][1:].to(device=flags.device))
            predicted_embedding = predictor_network(
                batch['partial_obs'][1:].to(device=flags.device))

        # Accumulate per-position visit counts (used later to draw heatmaps).
        position_coord, position_counts = torch.unique(
            batch["agent_position"].view(-1, 2), return_counts=True, dim=0)
        for i, coord in enumerate(position_coord):
            coord = tuple(coord[:].cpu().numpy())
            position_count[coord] = position_count.get(
                coord, 0) + position_counts[i].item()

        # Record the action histogram for visualization only; it does not
        # influence the loss.
        action_id, count_action = torch.unique(batch["action"].flatten(),
                                               return_counts=True)
        acted_id, action_acted = torch.unique(
            (batch["action"] + 1) * batch["action_acted"] - 1,
            return_counts=True)
        action_hist.add(action_id, count_action.cpu().float(),
                        acted_id[1:].cpu(), action_acted[1:].cpu().float())

        intrinsic_rewards = torch.norm(
            predicted_embedding.detach() - random_embedding.detach(),
            dim=2, p=2)
        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        num_samples = flags.unroll_length * flags.batch_size
        actions_flat = batch['action'][1:].reshape(
            num_samples).cpu().detach().numpy()
        intrinsic_rewards_flat = intrinsic_rewards.reshape(
            num_samples).cpu().detach().numpy()

        rnd_loss = flags.rnd_loss_coef * \
            losses.compute_forward_dynamics_loss(predicted_embedding,
                                                 random_embedding.detach())

        learner_outputs, unused_state = model(batch, initial_agent_state)
        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1] for key, tensor in learner_outputs.items()
        }

        rewards = batch['reward']
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)

        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(
            learner_outputs['policy_logits'],
            batch['action'],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss + rnd_loss

        episode_returns = batch['episode_return'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'rnd_loss': rnd_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
        }

        scheduler.step()
        optimizer.zero_grad()
        predictor_optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(predictor_network.parameters(),
                                 flags.max_grad_norm)
        optimizer.step()
        predictor_optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
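

# ----------------------------------------------------------------------------
# Variant: Random Network Distillation (RND) intrinsic reward computed on the
# full frame observations, without the position/action logging of the
# previous variant.
# ----------------------------------------------------------------------------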
def learn(actor_model,
          model,
          random_target_network,
          predictor_network,
          batch,
          initial_agent_state,
          optimizer,
          predictor_optimizer,
          scheduler,
          flags,
          frames=None,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        random_embedding = random_target_network(
            batch['frame'][1:].to(device=flags.device))
        predicted_embedding = predictor_network(
            batch['frame'][1:].to(device=flags.device))

        intrinsic_rewards = torch.norm(
            predicted_embedding.detach() - random_embedding.detach(),
            dim=2, p=2)
        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        num_samples = flags.unroll_length * flags.batch_size
        actions_flat = batch['action'][1:].reshape(
            num_samples).cpu().detach().numpy()
        intrinsic_rewards_flat = intrinsic_rewards.reshape(
            num_samples).cpu().detach().numpy()

        rnd_loss = flags.rnd_loss_coef * \
            losses.compute_forward_dynamics_loss(predicted_embedding,
                                                 random_embedding.detach())

        learner_outputs, unused_state = model(batch, initial_agent_state)
        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1] for key, tensor in learner_outputs.items()
        }

        rewards = batch['reward']
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)

        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(
            learner_outputs['policy_logits'],
            batch['action'],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss + rnd_loss

        episode_returns = batch['episode_return'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'rnd_loss': rnd_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
        }

        scheduler.step()
        optimizer.zero_grad()
        predictor_optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(predictor_network.parameters(),
                                 flags.max_grad_norm)
        optimizer.step()
        predictor_optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
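

# ----------------------------------------------------------------------------
# Variant: impact-driven intrinsic reward (in the style of RIDE). The reward
# is the change the agent's action causes in the learned state embedding,
# scaled by an episodic state-count term; forward and inverse dynamics losses
# shape the embedding.
# ----------------------------------------------------------------------------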
def learn(actor_model,
          model,
          state_embedding_model,
          forward_dynamics_model,
          inverse_dynamics_model,
          batch,
          initial_agent_state,
          optimizer,
          state_embedding_optimizer,
          forward_dynamics_optimizer,
          inverse_dynamics_optimizer,
          scheduler,
          flags,
          frames=None,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        count_rewards = torch.ones(
            (flags.unroll_length, flags.batch_size),
            dtype=torch.float32).to(device=flags.device)
        count_rewards = batch['episode_state_count'][1:].float().to(
            device=flags.device)

        if flags.use_fullobs_intrinsic:
            state_emb = state_embedding_model(batch, next_state=False) \
                .reshape(flags.unroll_length, flags.batch_size, 128)
            next_state_emb = state_embedding_model(batch, next_state=True) \
                .reshape(flags.unroll_length, flags.batch_size, 128)
        else:
            state_emb = state_embedding_model(
                batch['partial_obs'][:-1].to(device=flags.device))
            next_state_emb = state_embedding_model(
                batch['partial_obs'][1:].to(device=flags.device))

        pred_next_state_emb = forward_dynamics_model(
            state_emb, batch['action'][1:].to(device=flags.device))
        pred_actions = inverse_dynamics_model(state_emb, next_state_emb)

        control_rewards = torch.norm(next_state_emb - state_emb, dim=2, p=2)

        intrinsic_rewards = count_rewards * control_rewards
        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        forward_dynamics_loss = flags.forward_loss_coef * \
            losses.compute_forward_dynamics_loss(pred_next_state_emb,
                                                 next_state_emb)
        inverse_dynamics_loss = flags.inverse_loss_coef * \
            losses.compute_inverse_dynamics_loss(pred_actions,
                                                 batch['action'][1:])

        learner_outputs, unused_state = model(batch, initial_agent_state)
        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1] for key, tensor in learner_outputs.items()
        }

        rewards = batch['reward']
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)

        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(
            learner_outputs['policy_logits'],
            batch['action'],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss + \
            forward_dynamics_loss + inverse_dynamics_loss

        episode_returns = batch['episode_return'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
            'mean_control_rewards': torch.mean(control_rewards).item(),
            'mean_count_rewards': torch.mean(count_rewards).item(),
            'forward_dynamics_loss': forward_dynamics_loss.item(),
            'inverse_dynamics_loss': inverse_dynamics_loss.item(),
        }

        scheduler.step()
        optimizer.zero_grad()
        state_embedding_optimizer.zero_grad()
        forward_dynamics_optimizer.zero_grad()
        inverse_dynamics_optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(state_embedding_model.parameters(),
                                 flags.max_grad_norm)
        nn.utils.clip_grad_norm_(forward_dynamics_model.parameters(),
                                 flags.max_grad_norm)
        nn.utils.clip_grad_norm_(inverse_dynamics_model.parameters(),
                                 flags.max_grad_norm)
        optimizer.step()
        state_embedding_optimizer.step()
        forward_dynamics_optimizer.step()
        inverse_dynamics_optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
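

# ----------------------------------------------------------------------------
# Variant: action-usage intrinsic reward. Each action's bonus decreases with
# the ratio reported by action_hist.usage_ratio(), so rarely exercised
# actions are rewarded more; the bonus is granted only at steps where the
# action actually changed the state, and is further scaled by an episodic
# state-count term. The forward/inverse dynamics models are optional and,
# when provided, are used to detect whether an action affected the state
# embedding and are trained alongside the policy.
# ----------------------------------------------------------------------------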
def learn(actor_model,
          model,
          action_hist,
          batch,
          initial_agent_state,
          optimizer,
          # action_distribution_optimizer,
          scheduler,
          flags,
          frames=None,
          position_count=None,
          state_embedding_model=None,
          forward_dynamics_model=None,
          inverse_dynamics_model=None,
          state_embedding_optimizer=None,
          forward_dynamics_optimizer=None,
          inverse_dynamics_optimizer=None,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        count_rewards = torch.ones(
            (flags.unroll_length, flags.batch_size),
            dtype=torch.float32).to(device=flags.device)
        count_rewards = batch['episode_state_count'][1:].float().to(
            device=flags.device)

        action_id, count_action = torch.unique(batch["action"].flatten(),
                                               return_counts=True)

        # Accumulate per-position visit counts (used later to draw heatmaps).
        position_coord, position_counts = torch.unique(
            batch["agent_position"].view(-1, 2), return_counts=True, dim=0)
        for i, coord in enumerate(position_coord):
            coord = tuple(coord[:].cpu().numpy())
            position_count[coord] = position_count.get(
                coord, 0) + position_counts[i].item()

        if state_embedding_model:
            current_state_embedding = state_embedding_model(
                batch['partial_obs'][:-1].to(device=flags.device))
            next_state_embedding = state_embedding_model(
                batch['partial_obs'][1:].to(device=flags.device))

            # An action counts as having "acted" if it changed the state
            # embedding by more than the configured threshold.
            batch_action_acted = torch.abs(
                current_state_embedding
                - next_state_embedding).sum(dim=2) > flags.change_treshold
            acted_id, action_acted = torch.unique(
                (batch["action"][:-1] + 1) * batch_action_acted.long() - 1,
                return_counts=True)

            pred_next_state_emb = forward_dynamics_model(
                current_state_embedding,
                batch['action'][1:].to(device=flags.device))
            pred_actions = inverse_dynamics_model(current_state_embedding,
                                                  next_state_embedding)

            forward_dynamics_loss = flags.forward_loss_coef * \
                losses.compute_forward_dynamics_loss(pred_next_state_emb,
                                                     next_state_embedding)
            inverse_dynamics_loss = flags.inverse_loss_coef * \
                losses.compute_inverse_dynamics_loss(pred_actions,
                                                     batch['action'][1:])
        else:
            acted_id, action_acted = torch.unique(
                (batch["action"] + 1) * batch["action_acted"] - 1,
                return_counts=True)
            batch_action_acted = batch["action_acted"][:-1].byte()
            forward_dynamics_loss = torch.zeros(1)
            inverse_dynamics_loss = torch.zeros(1)

        action_hist.add(action_id, count_action.cpu().float(),
                        acted_id[1:].cpu(), action_acted[1:].cpu().float())

        action_rewards = torch.zeros_like(batch["action"][:-1]).float()
        acted_ratio = action_hist.usage_ratio()
        if flags.action_dist_decay_coef == 0:
            reward_for_an_action = 1 - acted_ratio
        else:
            reward_for_an_action = torch.exp(
                -acted_ratio * flags.action_dist_decay_coef)
            reward_for_an_action[acted_ratio == 1] = 0
        reward_for_an_action[torch.isnan(reward_for_an_action)] = 0
        assert torch.all(reward_for_an_action >= 0), \
            "Action rewards must be non-negative"

        for id in action_id:
            action_rewards[(batch["action"][:-1] == id.item())
                           & batch_action_acted] = reward_for_an_action[id]

        intrinsic_rewards = count_rewards * action_rewards
        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        learner_outputs, unused_state = model(batch, initial_agent_state)
        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1] for key, tensor in learner_outputs.items()
        }

        rewards = batch['reward']
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)

        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(
            learner_outputs['policy_logits'],
            batch['action'],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss  # + act_distrib_loss

        episode_returns = batch['episode_return'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
            'mean_action_rewards': torch.mean(action_rewards).item(),
            'mean_count_rewards': torch.mean(count_rewards).item(),
            'forward_dynamics_loss': forward_dynamics_loss.item(),
            'inverse_dynamics_loss': inverse_dynamics_loss.item(),
        }

        scheduler.step()
        optimizer.zero_grad()
        if state_embedding_model:
            state_embedding_optimizer.zero_grad()
            forward_dynamics_optimizer.zero_grad()
            inverse_dynamics_optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        if state_embedding_model:
            nn.utils.clip_grad_norm_(state_embedding_model.parameters(),
                                     flags.max_grad_norm)
            nn.utils.clip_grad_norm_(forward_dynamics_model.parameters(),
                                     flags.max_grad_norm)
            nn.utils.clip_grad_norm_(inverse_dynamics_model.parameters(),
                                     flags.max_grad_norm)
        optimizer.step()
        if state_embedding_model:
            state_embedding_optimizer.step()
            forward_dynamics_optimizer.step()
            inverse_dynamics_optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats