コード例 #1
0
def create_rl_experiment(experiment_config):
    """Build an `RLExperiment` from a full experiment configuration.

    The agent, the environment and both hook lists are resolved through the
    `getters` helpers. Every other `RLExperiment` argument is read from the
    attribute of the same name on `experiment_config`.

    Args:
        experiment_config: the config to use for creating the experiment.

    Returns:
        The newly created `RLExperiment`.
    """
    cfg = experiment_config

    agent = getters.get_agent(cfg.agent_config,
                              cfg.model_config,
                              cfg.run_config)

    env_cfg = cfg.environment_config
    env = getters.get_environment(env_cfg.module,
                                  env_cfg.env_id,
                                  **env_cfg.params)

    # Hooks are resolved in train-then-eval order, as before.
    train_hooks = getters.get_hooks(cfg.train_hooks_config)
    eval_hooks = getters.get_hooks(cfg.eval_hooks_config)

    return RLExperiment(
        agent=agent,
        env=env,
        train_steps=cfg.train_steps,
        train_episodes=cfg.train_episodes,
        first_update=cfg.first_update,
        update_frequency=cfg.update_frequency,
        eval_steps=cfg.eval_steps,
        train_hooks=train_hooks,
        eval_hooks=eval_hooks,
        eval_delay_secs=cfg.eval_delay_secs,
        continuous_eval_throttle_secs=cfg.continuous_eval_throttle_secs,
        eval_every_n_steps=cfg.eval_every_n_steps,
        delay_workers_by_global_step=cfg.delay_workers_by_global_step,
        export_strategies=cfg.export_strategies,
        train_steps_per_iteration=cfg.train_steps_per_iteration)
コード例 #2
0
ファイル: experiment.py プロジェクト: vdt/polyaxon
def _create_input_fn_for_mode(input_data_config, mode, scope):
    """Build an input function from one input-data config section.

    Args:
        input_data_config: a train or eval input-data config. Its
            `pipeline_config`, `input_type`, `x` and `y` attributes are
            forwarded to `create_input_data_fn`.
        mode: the `ModeKeys` value the input function is created for.
        scope: the scope name passed through to `create_input_data_fn`.

    Returns:
        Whatever `create_input_data_fn` returns for these arguments.
    """
    return create_input_data_fn(
        pipeline_config=input_data_config.pipeline_config,
        mode=mode,
        scope=scope,
        input_type=input_data_config.input_type,
        x=input_data_config.x,
        y=input_data_config.y)


def create_experiment(experiment_config):
    """Creates a new `Experiment` instance.

    Builds the training and evaluation input functions and resolves the
    estimator and the train/eval hooks through `getters`. The remaining
    settings are forwarded from `experiment_config` to `Experiment`.

    Args:
        experiment_config: the config to use for creating the experiment.

    Returns:
        The newly created `Experiment`.
    """
    # Train and eval input functions differ only in their config section,
    # mode and scope name, so both go through the same helper.
    train_input_fn = _create_input_fn_for_mode(
        experiment_config.train_input_data_config,
        mode=ModeKeys.TRAIN,
        scope='train_input_fn')
    eval_input_fn = _create_input_fn_for_mode(
        experiment_config.eval_input_data_config,
        mode=ModeKeys.EVAL,
        scope='eval_input_fn')

    estimator = getters.get_estimator(experiment_config.estimator_config,
                                      experiment_config.model_config,
                                      experiment_config.run_config)
    train_hooks = getters.get_hooks(experiment_config.train_hooks_config)
    eval_hooks = getters.get_hooks(experiment_config.eval_hooks_config)

    return Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        train_steps=experiment_config.train_steps,
        eval_steps=experiment_config.eval_steps,
        train_hooks=train_hooks,
        eval_hooks=eval_hooks,
        eval_delay_secs=experiment_config.eval_delay_secs,
        continuous_eval_throttle_secs=(
            experiment_config.continuous_eval_throttle_secs),
        eval_every_n_steps=experiment_config.eval_every_n_steps,
        delay_workers_by_global_step=(
            experiment_config.delay_workers_by_global_step),
        export_strategies=experiment_config.export_strategies,
        train_steps_per_iteration=experiment_config.train_steps_per_iteration)