예제 #1
0
파일: learn.py 프로젝트: donlee90/ml-agents
def run_training(run_seed: int, options: RunOptions, num_areas: int) -> None:
    """
    Set up and run one training session.

    All setup, up to and including construction of the trainer controller,
    is timed under "run_training.setup". After learning stops, whether it
    returned or raised, the environment manager is closed and
    write_run_options, write_timing_tree and write_training_status are called.
    :param run_seed: Random seed used for training.
    :param options: Parsed command line arguments.
    :param num_areas: Number of training areas to instantiate.
    """
    with hierarchical_timer("run_training.setup"):
        torch_utils.set_torch_config(options.torch_settings)
        ckpt = options.checkpoint_settings
        env_cfg = options.env_settings
        engine_cfg = options.engine_settings

        logs_dir = ckpt.run_logs_dir
        base_port = env_cfg.base_port

        # Check the output / init directories against the resume and force flags.
        validate_existing_directories(
            ckpt.write_path, ckpt.resume, ckpt.force, ckpt.maybe_init_path
        )
        # The run logs directory may already exist (e.g. when resuming).
        os.makedirs(logs_dir, exist_ok=True)

        if checkpoint_resume := ckpt.resume:
            # Resuming: restore the previously saved training status.
            status_file = os.path.join(logs_dir, "training_status.json")
            GlobalTrainingStatus.load_state(status_file)
        elif ckpt.maybe_init_path is not None:
            # Initializing from another run: set the full init path on all behaviors.
            setup_init_path(options.behaviors, ckpt.maybe_init_path)
        del checkpoint_resume

        # Register every stats writer plugin with the StatsReporter.
        for writer in register_stats_writer_plugins(options):
            StatsReporter.add_writer(writer)

        # Without an environment path, no port is handed to the factory.
        port: Optional[int] = None if env_cfg.env_path is None else base_port
        # Unity environment requires an absolute path for the logs directory.
        absolute_logs_dir = os.path.abspath(logs_dir)
        env_factory = create_environment_factory(
            env_cfg.env_path,
            engine_cfg.no_graphics,
            run_seed,
            num_areas,
            port,
            env_cfg.env_args,
            absolute_logs_dir,
        )

        env_manager = SubprocessEnvManager(env_factory, options, env_cfg.num_envs)
        param_manager = EnvironmentParameterManager(
            options.environment_parameters, run_seed, restore=ckpt.resume
        )

        factory = TrainerFactory(
            trainer_config=options.behaviors,
            output_path=ckpt.write_path,
            train_model=not ckpt.inference,
            load_model=ckpt.resume,
            seed=run_seed,
            param_manager=param_manager,
            init_path=ckpt.maybe_init_path,
            multi_gpu=False,
        )
        # The controller drives the training loop started below.
        controller = TrainerController(
            factory,
            ckpt.write_path,
            ckpt.run_id,
            param_manager,
            not ckpt.inference,
            run_seed,
        )

    # Run training; cleanup and bookkeeping happen even if learning raises.
    try:
        controller.start_learning(env_manager)
    finally:
        env_manager.close()
        write_run_options(ckpt.write_path, options)
        write_timing_tree(logs_dir)
        write_training_status(logs_dir)
예제 #2
0
def run_training(run_seed: int, options: RunOptions) -> None:
    """
    Set up and run one training session.

    Output goes under "results/<run_id>", with run logs in its "run_logs"
    sub-directory. All setup is timed under "run_training.setup". After
    learning stops, whether it returned or raised, the environment manager
    is closed and write_run_options, write_timing_tree and
    write_training_status are called.
    :param run_seed: Random seed used for training.
    :param options: Parsed command line arguments.
    """
    with hierarchical_timer("run_training.setup"):
        ckpt = options.checkpoint_settings
        env_cfg = options.env_settings
        engine_cfg = options.engine_settings

        results_root = "results"
        write_path = os.path.join(results_root, ckpt.run_id)
        # Only resolve an init path when the user asked to initialize from a run.
        if ckpt.initialize_from is not None:
            maybe_init_path = os.path.join(results_root, ckpt.initialize_from)
        else:
            maybe_init_path = None
        run_logs_dir = os.path.join(write_path, "run_logs")
        base_port = env_cfg.base_port

        # Check the output / init directories against the resume and force flags.
        validate_existing_directories(
            write_path, ckpt.resume, ckpt.force, maybe_init_path
        )
        # The run logs directory may already exist (e.g. when resuming).
        os.makedirs(run_logs_dir, exist_ok=True)
        if ckpt.resume:
            # Resuming: restore the previously saved training status.
            status_file = os.path.join(run_logs_dir, "training_status.json")
            GlobalTrainingStatus.load_state(status_file)

        # Build all stats writers first, then register them in a fixed order.
        # Past Tensorboard data is cleared unless the run is being resumed.
        tensorboard_writer = TensorboardWriter(
            write_path, clear_past_data=not ckpt.resume
        )
        gauge_writer = GaugeWriter()
        console_writer = ConsoleWriter()
        for writer in (tensorboard_writer, gauge_writer, console_writer):
            StatsReporter.add_writer(writer)

        # Without an environment path, no port is handed to the factory.
        port: Optional[int] = None if env_cfg.env_path is None else base_port
        # Unity environment requires an absolute path for the logs directory.
        absolute_logs_dir = os.path.abspath(run_logs_dir)
        env_factory = create_environment_factory(
            env_cfg.env_path,
            engine_cfg.no_graphics,
            run_seed,
            port,
            env_cfg.env_args,
            absolute_logs_dir,
        )
        # Copy the engine-related command line settings into an EngineConfig.
        engine_config = EngineConfig(
            width=engine_cfg.width,
            height=engine_cfg.height,
            quality_level=engine_cfg.quality_level,
            time_scale=engine_cfg.time_scale,
            target_frame_rate=engine_cfg.target_frame_rate,
            capture_frame_rate=engine_cfg.capture_frame_rate,
        )
        env_manager = SubprocessEnvManager(
            env_factory, engine_config, env_cfg.num_envs
        )
        param_manager = EnvironmentParameterManager(
            options.environment_parameters, run_seed, restore=ckpt.resume
        )

        factory = TrainerFactory(
            trainer_config=options.behaviors,
            output_path=write_path,
            train_model=not ckpt.inference,
            load_model=ckpt.resume,
            seed=run_seed,
            param_manager=param_manager,
            init_path=maybe_init_path,
            multi_gpu=False,
        )
        # The controller drives the training loop started below.
        controller = TrainerController(
            factory,
            write_path,
            ckpt.run_id,
            param_manager,
            not ckpt.inference,
            run_seed,
        )

    # Run training; cleanup and bookkeeping happen even if learning raises.
    try:
        controller.start_learning(env_manager)
    finally:
        env_manager.close()
        write_run_options(write_path, options)
        write_timing_tree(run_logs_dir)
        write_training_status(run_logs_dir)
예제 #3
0
def test_existing_directories(tmp_path):
    """
    Exercise validate_existing_directories for a run directory that is first
    missing and then present, and for an init directory that is first missing
    and then present. Positional flags are (path, resume, force[, init_path]).
    """
    run_dir = os.path.join(tmp_path, "runid")

    # Run directory does not exist yet: a plain run is accepted...
    validate_existing_directories(run_dir, False, False)
    # ...but resuming from it is rejected.
    with pytest.raises(UnityTrainerException):
        validate_existing_directories(run_dir, True, False)

    os.mkdir(run_dir)

    # Run directory exists: a plain run (no resume, no force) is rejected...
    with pytest.raises(UnityTrainerException):
        validate_existing_directories(run_dir, False, False)
    # ...while resume alone, then force alone, are both accepted.
    for resume, force in ((True, False), (False, True)):
        validate_existing_directories(run_dir, resume, force)

    # Initialize-from option: a missing init directory is rejected.
    init_dir = os.path.join(tmp_path, "runid2")
    with pytest.raises(UnityTrainerException):
        validate_existing_directories(run_dir, False, True, init_dir)

    # Once the init directory exists the same call is accepted.
    os.mkdir(init_dir)
    validate_existing_directories(run_dir, False, True, init_dir)