import logging

import numpy as np
import torch
from torch import optim


def batch_update_weights(optimizer: optim.Optimizer, network: Network, batch):
    """Run one training step over the whole batch at once and update the weights."""
    optimizer.zero_grad()

    value_loss = 0
    reward_loss = 0
    policy_loss = 0

    # Format training data: group observations, actions and targets by unroll
    # step rather than by sample, so each step can be processed as one batch.
    image_batch = np.array([item[0] for item in batch])
    action_batches = np.array([item[1] for item in batch])
    target_batches = np.array([item[2] for item in batch])
    action_batches = np.swapaxes(action_batches, 0, 1)
    target_batches = target_batches.transpose(1, 2, 0)

    # Run initial inference
    values, rewards, policy_logits, hidden_states = network.batch_initial_inference(
        image_batch)
    predictions = [(1.0, values, rewards, policy_logits)]

    # Run recurrent inferences
    for action_batch in action_batches:
        values, rewards, policy_logits, hidden_states = network.batch_recurrent_inference(
            hidden_states, action_batch)
        # Scale so the recurrent inference updates together carry the same
        # weight as the single initial inference update.
        predictions.append(
            (1.0 / len(action_batches), values, rewards, policy_logits))
        hidden_states = scale_gradient(hidden_states, 0.5)

    # Calculate losses
    for target_batch, prediction_batch in zip(target_batches, predictions):
        gradient_scale, values, rewards, policy_logits = prediction_batch
        target_values, target_rewards, target_policies = (
            torch.tensor(list(item), dtype=torch.float32, device=values.device)
            for item in target_batch)
        gradient_scale = torch.tensor(gradient_scale, dtype=torch.float32,
                                      device=values.device)

        value_loss += gradient_scale * scalar_loss(values, target_values)
        reward_loss += gradient_scale * scalar_loss(rewards, target_rewards)
        policy_loss += gradient_scale * cross_entropy_with_logits(
            policy_logits, target_policies, dim=1)

    value_loss = value_loss.mean() / len(batch)
    reward_loss = reward_loss.mean() / len(batch)
    policy_loss = policy_loss.mean() / len(batch)
    total_loss = value_loss + reward_loss + policy_loss

    logging.info('Training step {} losses'.format(network.training_steps())
                 + ' | Total: {:.5f}'.format(total_loss)
                 + ' | Value: {:.5f}'.format(value_loss)
                 + ' | Reward: {:.5f}'.format(reward_loss)
                 + ' | Policy: {:.5f}'.format(policy_loss))

    # Update weights
    total_loss.backward()
    optimizer.step()
    network.increment_step()

    return total_loss, value_loss, reward_loss, policy_loss
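
# The helpers `scale_gradient`, `scalar_loss` and `cross_entropy_with_logits`
# are used above but not defined in this section. The sketches below are
# assumptions, written only to match how they are called here; the real
# implementations elsewhere in the repository may differ (for example,
# `scalar_loss` could be a cross-entropy over a discretised value support, as
# in the MuZero paper, rather than a squared error).

def scale_gradient(tensor, scale):
    """Scale the gradient flowing into `tensor` without changing its value."""
    return tensor * scale + tensor.detach() * (1 - scale)


def scalar_loss(prediction, target):
    """Assumed: squared error between predicted and target scalars."""
    return (prediction - target) ** 2


def cross_entropy_with_logits(logits, target, dim=-1):
    """Assumed: cross-entropy between logits and a soft target distribution."""
    return -(target * torch.log_softmax(logits, dim=dim)).sum(dim=dim)
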
def update_weights(optimizer: optim.Optimizer, network: Network, batch):
    """Per-sample version of the training step: unroll each trajectory individually."""
    optimizer.zero_grad()

    value_loss = 0
    reward_loss = 0
    policy_loss = 0

    for image, actions, targets in batch:
        # Initial step, from the real observation.
        value, reward, policy_logits, hidden_state = network.initial_inference(
            image)
        predictions = [(1.0 / len(batch), value, reward, policy_logits)]

        # Recurrent steps, from action and previous hidden state.
        for action in actions:
            value, reward, policy_logits, hidden_state = network.recurrent_inference(
                hidden_state, action)
            # TODO: Try not scaling this for efficiency.
            # Scale so the recurrent inference updates together carry the same
            # weight as the single initial inference update.
            predictions.append(
                (1.0 / len(actions), value, reward, policy_logits))
            hidden_state = scale_gradient(hidden_state, 0.5)

        for prediction, target in zip(predictions, targets):
            gradient_scale, value, reward, policy_logits = prediction
            target_value, target_reward, target_policy = (
                torch.tensor(item, dtype=torch.float32, device=value.device)
                for item in target)

            # Past the end of the episode there is no policy target, so stop.
            if len(target_policy) == 0:
                break

            value_loss += gradient_scale * scalar_loss(value, target_value)
            reward_loss += gradient_scale * scalar_loss(reward, target_reward)
            policy_loss += gradient_scale * cross_entropy_with_logits(
                policy_logits, target_policy)

    value_loss /= len(batch)
    reward_loss /= len(batch)
    policy_loss /= len(batch)
    total_loss = value_loss + reward_loss + policy_loss
    # NOTE: `gradient_scale` here is whatever value was left over from the last
    # prediction-loop iteration, so this only rescales all gradients uniformly.
    scaled_loss = scale_gradient(total_loss, gradient_scale)

    logging.info('Training step {} losses'.format(network.training_steps())
                 + ' | Total: {:.5f}'.format(total_loss)
                 + ' | Value: {:.5f}'.format(value_loss)
                 + ' | Reward: {:.5f}'.format(reward_loss)
                 + ' | Policy: {:.5f}'.format(policy_loss))

    scaled_loss.backward()
    optimizer.step()
    network.increment_step()
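
# A minimal usage sketch for the two update functions above, assuming a
# surrounding training setup. `config`, `replay_buffer` and `storage` are
# hypothetical stand-ins (their attribute and method names below are
# assumptions, not part of this file), and `Network` is assumed to be an
# nn.Module-style class exposing `.parameters()`; only the calls into
# `batch_update_weights` / `update_weights` reflect the code above.

def train_network(config, storage, replay_buffer, use_batched_updates=True):
    network = Network(config)
    optimizer = optim.SGD(network.parameters(),
                          lr=config.lr_init,
                          momentum=config.momentum,
                          weight_decay=config.weight_decay)

    for _ in range(config.training_steps):
        # Sample a batch of (image, actions, targets) tuples from replay.
        batch = replay_buffer.sample_batch(config.num_unroll_steps, config.td_steps)
        if use_batched_updates:
            batch_update_weights(optimizer, network, batch)
        else:
            update_weights(optimizer, network, batch)

    storage.save_network(network.training_steps(), network)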