Beispiel #1
0
    def __init__(
        self,
        trainer_factory: TrainerFactory,
        output_path: str,
        run_id: str,
        param_manager: EnvironmentParameterManager,
        train: bool,
        training_seed: int,
    ):
        """
        Set up trainer bookkeeping and seed the random number generators.

        :param trainer_factory: Factory used to create trainers. Its
            ``ghost_controller`` attribute is also kept on this object.
        :param output_path: Path to save the model.
        :param run_id: The sub-directory name for model and summary statistics.
        :param param_manager: EnvironmentParameterManager object which stores
            information about all environment parameters.
        :param train: Whether to train the model, or only run inference.
        :param training_seed: Seed passed to ``np.random.seed`` and
            ``tf.set_random_seed`` for NumPy and TensorFlow random number
            generation.
        """
        # Trainers keyed by name (presumably brain/behavior name; TODO confirm).
        self.trainers: Dict[str, Trainer] = {}
        # defaultdict(set) so a new brain name starts with an empty set of identifiers.
        self.brain_name_to_identifier: Dict[str, Set] = defaultdict(set)
        self.trainer_factory = trainer_factory
        self.output_path = output_path
        self.logger = get_logger(__name__)
        self.run_id = run_id
        self.train_model = train
        self.param_manager = param_manager
        # Same object as the factory's ghost_controller (a reference, not a copy).
        self.ghost_controller = self.trainer_factory.ghost_controller
        self.registered_behavior_ids: Set[str] = set()

        # Thread list and stop flag start empty/False here; presumably populated
        # and checked by threaded trainer code elsewhere — verify.
        self.trainer_threads: List[threading.Thread] = []
        self.kill_trainers = False
        # Seeds NumPy's global RNG, then TensorFlow's (graph-level) seed.
        np.random.seed(training_seed)
        tf.set_random_seed(training_seed)
        # Cached result of get_rank(); presumably the distributed worker rank — confirm.
        self.rank = get_rank()
Beispiel #2
0
    def __init__(
        self,
        seed: int,
        behavior_spec: BehaviorSpec,
        trainer_settings: TrainerSettings,
        tanh_squash: bool = False,
        reparameterize: bool = False,
        condition_sigma_on_obs: bool = True,
        create_tf_graph: bool = True,
    ):
        """
        Initialize the TensorFlow policy.

        :param seed: Random seed to use for TensorFlow.
        :param behavior_spec: The corresponding BehaviorSpec for this policy.
        :param trainer_settings: The trainer parameters.
        :param tanh_squash: Forwarded unchanged to the base policy constructor.
        :param reparameterize: Forwarded unchanged to the base policy constructor.
        :param condition_sigma_on_obs: Forwarded unchanged to the base policy
            constructor.
        :param create_tf_graph: If True, ``self.create_tf_graph()`` is called at
            the end of construction; otherwise it is not called here.
        :raises UnityPolicyException: If the action spec has both continuous and
            discrete actions (mixed action spaces are not supported by the
            TensorFlow framework).
        """
        super().__init__(
            seed,
            behavior_spec,
            trainer_settings,
            tanh_squash,
            reparameterize,
            condition_sigma_on_obs,
        )
        # Reject mixed (continuous + discrete) action spaces before allocating
        # any TensorFlow graph or session.
        if (
            self.behavior_spec.action_spec.continuous_size > 0
            and self.behavior_spec.action_spec.discrete_size > 0
        ):
            raise UnityPolicyException(
                "TensorFlow does not support mixed action spaces. Please run with the Torch framework."
            )
        # for ghost trainer save/load snapshots
        self.assign_phs: List[tf.Tensor] = []
        self.assign_ops: List[tf.Operation] = []
        self.update_dict: Dict[str, tf.Tensor] = {}
        self.inference_dict: Dict[str, tf.Tensor] = {}
        self.first_normalization_update: bool = False

        # Each policy owns a dedicated tf.Graph and a Session bound to that graph.
        self.graph = tf.Graph()
        self.sess = tf.Session(
            config=tf_utils.generate_session_config(), graph=self.graph
        )
        self._initialize_tensorflow_references()
        # Optimizer-related handles start empty; presumably filled in when the
        # graph/optimizer is built — verify.
        self.grads = None
        self.update_batch: Optional[tf.Operation] = None
        self.trainable_variables: List[tf.Variable] = []
        # Cached result of get_rank(); presumably the distributed worker rank — confirm.
        self.rank = get_rank()
        if create_tf_graph:
            self.create_tf_graph()
Beispiel #3
0
 def __init__(self):
     """
     Initialize reporting state for a run.

     Records the wall-clock time at construction, starts with self-play
     reporting disabled (team index -1), and caches the value of get_rank().
     """
     self.training_start_time = time.time()
     # Self-play runs report ELO alongside reward; both flags start "off".
     self.self_play, self.self_play_team = False, -1
     self.rank = get_rank()