def update_weights(optimizer: optim.Optimizer, network: Network, batch):
    optimizer.zero_grad()

    value_loss = 0
    reward_loss = 0
    policy_loss = 0

    for image, actions, targets in batch:
        # Initial step, from the real observation.
        value, reward, policy_logits, hidden_state = network.initial_inference(
            image)
        predictions = [(1.0 / len(batch), value, reward, policy_logits)]

        # Recurrent steps, from action and previous hidden state.
        for action in actions:
            value, reward, policy_logits, hidden_state = network.recurrent_inference(
                hidden_state, action)
            # TODO: Try not scaling this for efficiency
            # Scale so that the recurrent inference updates have, in total, the
            # same weight as the initial inference update.
            predictions.append(
                (1.0 / len(actions), value, reward, policy_logits))
            hidden_state = scale_gradient(hidden_state, 0.5)

        for prediction, target in zip(predictions, targets):
            gradient_scale, value, reward, policy_logits = prediction
            target_value, target_reward, target_policy = \
                (torch.tensor(item, dtype=torch.float32, device=value.device)
                 for item in target)

            # Past the end of the episode.
            if len(target_policy) == 0:
                break

            value_loss += gradient_scale * scalar_loss(value, target_value)
            reward_loss += gradient_scale * scalar_loss(reward, target_reward)
            policy_loss += gradient_scale * cross_entropy_with_logits(
                policy_logits, target_policy)
            # print('val -------', value, target_value, scalar_loss(value, target_value))
            # print('rew -------', reward, target_reward, scalar_loss(reward, target_reward))
            # print('pol -------', policy_logits, target_policy, cross_entropy_with_logits(policy_logits, target_policy))

    value_loss /= len(batch)
    reward_loss /= len(batch)
    policy_loss /= len(batch)

    total_loss = value_loss + reward_loss + policy_loss
    scaled_loss = scale_gradient(total_loss, gradient_scale)

    logging.info('Training step {} losses'.format(network.training_steps()) +
                 ' | Total: {:.5f}'.format(total_loss) +
                 ' | Value: {:.5f}'.format(value_loss) +
                 ' | Reward: {:.5f}'.format(reward_loss) +
                 ' | Policy: {:.5f}'.format(policy_loss))

    scaled_loss.backward()
    optimizer.step()
    network.increment_step()
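
# The helpers used above (`scale_gradient`, `scalar_loss`,
# `cross_entropy_with_logits`) are defined elsewhere in this project. As a
# rough illustration of the gradient-scaling trick from the MuZero paper, a
# minimal PyTorch sketch could look like the function below (hypothetical
# name, not this project's actual implementation).
def _scale_gradient_sketch(tensor: torch.Tensor, scale: float) -> torch.Tensor:
    # The forward value is unchanged (scale + (1 - scale) == 1), but only the
    # `tensor * scale` term participates in autograd, so the backward gradient
    # is multiplied by `scale`.
    return tensor * scale + tensor.detach() * (1.0 - scale)
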
def play_game(config: MuZeroConfig, network: Network) -> Game:
    game = Game.from_config(config)

    while not game.terminal() and len(game.history) < config.max_moves:
        # At the root of the search tree we use the representation function to
        # obtain a hidden state given the current observation.
        root = Node(0)
        last_observation = game.make_image(-1)
        root.expand(game.to_play(), game.legal_actions(),
                    network.initial_inference(last_observation).numpy())
        root.add_exploration_noise(config)

        # logging.debug('Running MCTS on step {}.'.format(len(game.history)))

        # We then run a Monte Carlo Tree Search using only action sequences and
        # the model learned by the network.
        run_mcts(config, root, game.action_history(), network)
        action = root.select_action(config, len(game.history), network)
        game.apply(action)
        game.store_search_statistics(root)

    logging.info('Finished episode at step {} | cumulative reward: {}'
                 .format(len(game.obs_history), sum(game.rewards)))

    return game
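
# For context, `play_game` produces one self-play episode whose unrolled
# targets can be fed to `update_weights`. A minimal driver loop might look
# like the sketch below; `replay_buffer`, its `save_game`/`sample_batch`
# methods, and the config fields referenced here are assumptions for
# illustration, not necessarily this project's actual API.
def _training_loop_sketch(config: MuZeroConfig, network: Network,
                          optimizer: optim.Optimizer, replay_buffer) -> None:
    for _ in range(config.training_steps):
        # Self-play: generate a fresh game with the current network weights.
        game = play_game(config, network)
        replay_buffer.save_game(game)

        # Training: sample unrolled targets and apply one optimizer update.
        batch = replay_buffer.sample_batch(config.num_unroll_steps,
                                           config.td_steps)
        update_weights(optimizer, network, batch)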