示例#1
0
文件: inference.py 项目: ulrikah/rave
def run_offline_inference(agent: Trainer, env: CrossAdaptiveEnv):
    """Play one full episode of ``env`` with actions chosen by ``agent``.

    The environment is reset once. Then, for each observation, the agent's
    action is computed and passed to ``env.step``, until the third value
    returned by ``env.step`` (the done flag) is truthy. The second and fourth
    values returned by ``env.step`` are discarded, and nothing is returned.

    NOTE: the author observed that the computed action values all end up too
    close to the bound; the cause has not been identified.
    TODO: standardize the action. That might be difficult to do in live mode,
    but offline inference essentially works as is.
    """
    observation = env.reset()
    finished = False
    while not finished:
        chosen_action = agent.compute_action(observation)
        observation, _, finished, _ = env.step(chosen_action)
示例#2
0
文件: inference.py 项目: ulrikah/rave
def _standardize_features(env, features):
    """Return ``features`` standardized one value at a time as a numpy array.

    The value at position ``i`` is standardized with
    ``env.standardizer.get_standardized_value`` against
    ``env.analyser.analysis_features[i]``, i.e. values are paired with the
    analyser's feature list by position.
    """
    analysis_features = env.analyser.analysis_features
    return np.array([
        env.standardizer.get_standardized_value(analysis_features[i], value)
        for i, value in enumerate(features)
    ])


def run_live_inference(
    agent: Trainer,
    env: CrossAdaptiveEnv,
    *,
    max_steps: int = 1500,
    mediator: "Mediator | None" = None,
):
    """Run the agent in live mode, exchanging features and mappings via a mediator.

    Each iteration does the following:

    1. Fetches a (source, target) pair of feature frames from the mediator.
    2. Drops the leading timestamp from each frame.
    3. Standardizes both frames with ``env.standardizer``.
    4. Concatenates them into one observation and prints it, rounded.
    5. Asks ``agent`` for an action.
    6. Converts the action with ``env.action_to_mapping`` and sends the mapping
       back through the mediator.

    Args:
        agent: Policy whose ``compute_action(obs)`` picks the action.
        env: Used only for its standardizer, its analyser's feature list and
            ``action_to_mapping``. It is never stepped.
        max_steps: Number of mappings to send before stopping. Defaults to
            1500, the previously hard-coded limit. Frames where either
            feature set is ``None`` do not count toward this limit.
        mediator: Mediator to communicate through. A new ``Mediator()`` is
            created when ``None``. It is terminated when the function exits,
            including when an exception is raised.
    """
    if mediator is None:
        mediator = Mediator()

    steps_done = 0
    try:
        while steps_done < max_steps:
            source_features, target_features = mediator.get_features()
            if source_features is None or target_features is None:
                # NOTE(review): this retries immediately. Unless
                # get_features() blocks (not visible here), the loop
                # busy-waits while no features are available.
                continue

            # The first element of each frame is a timestamp, not a feature.
            standardized_source = _standardize_features(env, source_features[1:])
            standardized_target = _standardize_features(env, target_features[1:])
            obs = np.concatenate((standardized_source, standardized_target))
            print(np.round(obs, decimals=2))

            action = agent.compute_action(obs)
            mediator.send_effect_mapping(env.action_to_mapping(action))
            steps_done += 1
    finally:
        # Release the mediator even on errors or KeyboardInterrupt; the
        # original only terminated it after a clean finish.
        mediator.terminate()
    print("\n\n\tDONE\n\n")