Example #1
    def add_policy(self, parsed_behavior_id: BehaviorIdentifiers,
                   policy: TFPolicy) -> None:
        """
        Adds policy to trainer. The first policy encountered sets the wrapped
        trainer team.  This is to ensure that all agents from the same multi-agent
        team are grouped. All policies associated with this team are added to the
        wrapped trainer to be trained.
        :param parsed_behavior_id: Behavior identifiers that the policy should belong to.
        :param policy: Policy to associate with the behavior ID.
        """
        name_behavior_id = parsed_behavior_id.behavior_id
        team_id = parsed_behavior_id.team_id
        self.controller.subscribe_team_id(team_id, self)
        self.policies[name_behavior_id] = policy
        policy.create_tf_graph()

        self._name_to_parsed_behavior_id[name_behavior_id] = parsed_behavior_id
        # for saving/swapping snapshots
        policy.init_load_weights()

        # First policy or a new agent on the same team encountered
        if self.wrapped_trainer_team is None or team_id == self.wrapped_trainer_team:
            self.current_policy_snapshot[
                parsed_behavior_id.brain_name] = policy.get_weights()

            self._save_snapshot()  # Need to save after trainer initializes policy
            self.trainer.add_policy(parsed_behavior_id, policy)
            self._learning_team = self.controller.get_learning_team
            self.wrapped_trainer_team = team_id
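
The gate that matters in this example is the team check: only the first team encountered, or further agents from that same team, update current_policy_snapshot and get forwarded to the wrapped trainer; every other policy is merely registered for snapshot swapping. Below is a minimal, self-contained sketch of that gating pattern; the class, the plain weight lists, and the "BrainName?team=N" ID format are illustrative stand-ins, not the ML-Agents API.

from typing import Dict, List, Optional


class TeamGatedRegistry:
    # Illustrative stand-in for the gating logic, not the ML-Agents GhostTrainer.

    def __init__(self) -> None:
        self.policies: Dict[str, List[float]] = {}  # behavior_id -> weights
        self.current_policy_snapshot: Dict[str, List[float]] = {}
        self.wrapped_trainer_team: Optional[int] = None

    def add_policy(self, brain_name: str, team_id: int, weights: List[float]) -> None:
        behavior_id = f"{brain_name}?team={team_id}"
        self.policies[behavior_id] = weights
        # Only the first team seen (or more agents from that team) is trained;
        # policies from other teams are kept solely for snapshot swapping.
        if self.wrapped_trainer_team is None or team_id == self.wrapped_trainer_team:
            self.current_policy_snapshot[brain_name] = weights
            self.wrapped_trainer_team = team_id


registry = TeamGatedRegistry()
registry.add_policy("Striker", 0, [0.1, 0.2])
registry.add_policy("Striker", 1, [0.3, 0.4])       # other team: stored, not trained
print(registry.wrapped_trainer_team)                # 0
print(registry.current_policy_snapshot["Striker"])  # [0.1, 0.2]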
Example #2
    def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
        """
        Adds policy to trainer. For the first policy added, add a trainer
        to the policy and set the learning behavior name to name_behavior_id.
        :param name_behavior_id: Behavior ID that the policy should belong to.
        :param policy: Policy to associate with name_behavior_id.
        """
        self.policies[name_behavior_id] = policy
        policy.create_tf_graph()

        # First policy encountered
        if not self.learning_behavior_name:
            weights = policy.get_weights()
            self.current_policy_snapshot = weights
            self.trainer.add_policy(name_behavior_id, policy)
            self._save_snapshot(policy)  # Need to save after trainer initializes policy
            self.learning_behavior_name = name_behavior_id
            behavior_id_parsed = BehaviorIdentifiers.from_name_behavior_id(
                self.learning_behavior_name)
            team_id = behavior_id_parsed.behavior_ids["team"]
            self._stats_reporter.add_property(StatsPropertyType.SELF_PLAY_TEAM,
                                              team_id)
        else:
            # for saving/swapping snapshots
            policy.init_load_weights()
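
The team ID recorded for the stats reporter comes out of BehaviorIdentifiers.from_name_behavior_id, which splits the fully qualified behavior ID into its parts. Here is a small sketch of that kind of parsing, assuming a "BrainName?team=N" naming convention; the helper function and NamedTuple are hypothetical, not the library class.

from typing import NamedTuple


class ParsedBehaviorId(NamedTuple):
    behavior_id: str
    brain_name: str
    team_id: int


def parse_behavior_id(name_behavior_id: str) -> ParsedBehaviorId:
    # Split "BrainName?team=N" into brain name and team ID; default to team 0.
    brain_name, _, team_suffix = name_behavior_id.partition("?team=")
    team_id = int(team_suffix) if team_suffix else 0
    return ParsedBehaviorId(name_behavior_id, brain_name, team_id)


print(parse_behavior_id("Striker?team=1"))  # brain_name='Striker', team_id=1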
Example #3
    def _save_snapshot(self, policy: TFPolicy) -> None:
        """
        Saves a snapshot of the policy's current weights and its Elo into the
        rotating snapshot buffer, overwriting the oldest slot once the buffer
        holds `window` entries.
        """
        weights = policy.get_weights()
        try:
            self.policy_snapshots[self.snapshot_counter] = weights
        except IndexError:
            self.policy_snapshots.append(weights)
        self.policy_elos[self.snapshot_counter] = self.current_elo
        self.snapshot_counter = (self.snapshot_counter + 1) % self.window
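
The try/except IndexError combined with the modulo on self.window turns policy_snapshots into a fixed-size ring buffer: it appends until it holds window snapshots, then starts overwriting the oldest slot. A standalone sketch of the same pattern (names are illustrative):

from typing import List


class SnapshotRing:
    # Fixed-size ring buffer built the same way: append until full, then overwrite.

    def __init__(self, window: int) -> None:
        self.window = window
        self.snapshots: List[object] = []
        self.counter = 0

    def save(self, weights: object) -> None:
        try:
            self.snapshots[self.counter] = weights       # overwrite an existing slot
        except IndexError:
            self.snapshots.append(weights)               # still filling the buffer
        self.counter = (self.counter + 1) % self.window  # wrap after `window` saves


ring = SnapshotRing(window=3)
for step in range(5):
    ring.save(f"weights@{step}")
print(ring.snapshots)  # ['weights@3', 'weights@4', 'weights@2']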
Example #4
    def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
        """
        Adds policy to trainer. For the first policy added, add a trainer
        to the policy and set the learning behavior name to name_behavior_id.
        :param name_behavior_id: Behavior ID that the policy should belong to.
        :param policy: Policy to associate with name_behavior_id.
        """
        self.policies[name_behavior_id] = policy
        policy.create_tf_graph()

        # First policy encountered
        if not self.learning_behavior_name:
            weights = policy.get_weights()
            self.current_policy_snapshot = weights
            self.trainer.add_policy(name_behavior_id, policy)
            self._save_snapshot(policy)  # Need to save after trainer initializes policy
            self.learning_behavior_name = name_behavior_id
        else:
            # for saving/swapping snapshots
            policy.init_load_weights()