Example 1
def replay_game(game_data: Dict, game_arguments: game_args.GameArgs,
                game_server: server.ServerSocket,
                replay_arguments: replay_args.ReplayArgs):
    """Replays a game's actions."""
    actions = game_data['actions']

    # Start the game
    game = unity_game.UnityGame(game_arguments,
                                game_server,
                                auto_end_turn=False,
                                seed=game_data['seed'],
                                number_of_cards=game_data['num_cards'])
    prev_time: datetime = datetime.strptime(actions[0]['time'], DATE_FORMAT)

    for action in actions:
        if replay_arguments.is_realtime():
            action_time: datetime = datetime.strptime(action['time'],
                                                      DATE_FORMAT)
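            # Sleep for the (speed-scaled) gap since the previous action, capped at 0.5 s.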
            sleep_time = replay_arguments.get_playback_speed() * (
                action_time - prev_time).total_seconds()
            if sleep_time > 0:
                sleep_time = min(sleep_time, 0.5)
                time.sleep(sleep_time)
            prev_time = action_time

        if action['type'] == 'instruction':
            game.send_command(action['instruction'])
        elif action['type'] == 'movement':
            if action['move_id'] >= 0:
                if action['character'] == 'Leader':
                    game.execute_leader_action(
                        agent_actions.AgentAction(action['action']))
                elif action['character'] == 'Follower':
                    game.execute_follower_action(
                        agent_actions.AgentAction(action['action']))
                else:
                    raise ValueError('Unrecognized player type: ' +
                                     action['character'])
        elif action['type'] == 'finish command':
            game.execute_follower_action(agent_actions.AgentAction.STOP)
        elif action['type'] == 'end turn':
            game.end_turn()
    return game.get_score()
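A minimal sketch of driving the replay loop above from a saved JSON log. The replay_from_file helper, the 'replay.json' path, and the way the argument and server objects are constructed are assumptions for illustration; only the 'actions', 'seed', and 'num_cards' keys are taken from the code above.

import json

def replay_from_file(path, game_arguments, game_server, replay_arguments):
    # Load the recorded game; the dictionary must contain the keys used by replay_game.
    with open(path) as infile:
        game_data = json.load(infile)
    return replay_game(game_data, game_arguments, game_server, replay_arguments)

# e.g., score = replay_from_file('replay.json', game_arguments, game_server, replay_arguments)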
Example 2
    def __init__(self, json_object: Dict[str, Any]):
        super(MovementAction, self).__init__(json_object['time'])

        # 'Human' records correspond to the leader; any other character is the follower.
        self._agent: environment_objects.ObjectType = (
            environment_objects.ObjectType.LEADER
            if json_object['character'] == 'Human'
            else environment_objects.ObjectType.FOLLOWER)
        self._action: agent_actions.AgentAction = agent_actions.AgentAction(
            json_object['action'])

        self._card_result = json_object['card_result']
        self._set_result = json_object['set_result']

        # Filled in later, once the game state before and after the action is known.
        # (Optional requires 'from typing import Optional' alongside Dict/Any.)
        self._prior_game_info: Optional[state_delta.StateDelta] = None
        self._posterior_game_info: Optional[state_delta.StateDelta] = None
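For reference, a sketch of the movement record this constructor consumes. The field names come from the code above; the concrete values (timestamp format and action string) are placeholders and must match whatever the base class and agent_actions.AgentAction expect in the real data.

example_movement = {
    'time': '2020-01-01 12:00:00',  # placeholder; passed to the base class __init__
    'character': 'Human',           # 'Human' maps to LEADER; anything else to FOLLOWER
    'action': 'MF',                 # placeholder; must be a valid agent_actions.AgentAction value
    'card_result': None,
    'set_result': None,
}
movement_action = MovementAction(example_movement)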
Example 3
    def __init__(self, json_object: Dict[str, Any],
                 data_arguments: data_args.DataArgs):
        super(InstructionAction, self).__init__(json_object['time'])

        self._instruction_index: int = json_object['instruction_id']
        self._instruction: str = json_object['instruction']
        self._tokenized_instruction: List[str] = tokenize(
            self._instruction, data_arguments.case_sensitive())
        self._completed: bool = json_object['completed']

        # Aligned actions are only recorded for completed instructions; default to an
        # empty list so the attribute always exists.
        self._agent_aligned_actions: List[agent_actions.AgentAction] = []
        if self._completed:
            self._agent_aligned_actions = [
                agent_actions.AgentAction(action)
                for action in json_object['aligned_actions']
                if action != 'initial'
            ]
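The constructor above relies on a project-level tokenize helper that is not shown here. A minimal sketch of what such a helper could look like, assuming whitespace tokenization with optional lowercasing; the real implementation may differ.

from typing import List

def tokenize(text: str, case_sensitive: bool) -> List[str]:
    # Lowercase when the data arguments request case-insensitive matching.
    if not case_sensitive:
        text = text.lower()
    return text.split()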
Example 4
    def loss(
        self, examples: List[instruction_example.InstructionExample]
    ) -> Tuple[torch.Tensor, Any]:
        if self._parallelized:
            inputs = self._model.module.batch_inputs(examples)
        else:
            inputs = self._model.batch_inputs(examples)

        # Scores are size B x T x A, where A is the total number of possible actions.
        scores, auxiliaries = self._model(*inputs)

        auxiliary_losses: Dict[auxiliary.Auxiliary, Any] = dict()

        token_neglogprobs = -nn.functional.log_softmax(scores, dim=2)

        # Accumulate the negative log-probability of each gold action in the batch.
        losses: List[torch.Tensor] = []
        observation_index = 0
        for i, example in enumerate(examples):
            for j, action in enumerate(example.get_action_sequence()):
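                # Pick out the negative log-probability assigned to the gold action at step j.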
                step_scores = token_neglogprobs[i][j]
                action_score: torch.Tensor = step_scores[
                    agent_actions.AGENT_ACTIONS.index(
                        agent_actions.AgentAction(action))]
                losses.append(action_score)

            if self._end_to_end:
                if self.get_arguments().get_state_rep_args(
                ).full_observability():
                    plan_losses.compute_per_example_auxiliary_losses(
                        example, i, auxiliaries, list(self._auxiliaries),
                        auxiliary_losses,
                        self._args.get_decoder_args(
                        ).weight_trajectory_by_time(), True)
                else:
                    # Go through all the observations and make sure the predictions for each were correct.
                    # The final average will be over all observations equally (i.e., not reweighted by sequence length).
                    for observation in example.get_partial_observations():
                        plan_losses.compute_per_example_auxiliary_losses(
                            example, observation_index, auxiliaries,
                            list(self._auxiliaries.keys()), auxiliary_losses,
                            self._args.get_decoder_args().
                            weight_trajectory_by_time(), False, observation)
                        observation_index += 1

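        # Average each auxiliary's per-example terms into a single scalar loss.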
        for auxiliary_name in self._auxiliaries:
            auxiliary_losses[auxiliary_name] = torch.mean(
                torch.stack(tuple(auxiliary_losses[auxiliary_name])))

        return torch.mean(torch.stack(tuple(losses))), auxiliary_losses
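A minimal sketch of how the returned loss and auxiliary losses might be combined in a training step. The model_wrapper, optimizer, and aux_weights names, and the uniform default weight of 1.0, are assumptions for illustration rather than the original training code.

def train_step(model_wrapper, optimizer, batch, aux_weights):
    optimizer.zero_grad()
    main_loss, auxiliary_losses = model_wrapper.loss(batch)
    total_loss = main_loss
    for name, value in auxiliary_losses.items():
        # Each auxiliary term is already reduced to a scalar tensor above.
        total_loss = total_loss + aux_weights.get(name, 1.0) * value
    total_loss.backward()
    optimizer.step()
    return total_loss.item()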