Code example #1
File: duckie_test.py Project: Kuanta/Reinforcement
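# NOTE: excerpt from duckie_test.py; it assumes imports along the lines of
#   import torch
#   import gym_duckietown as duckie
#   from gym_duckietown.envs import DuckietownEnv
# plus the project's own wrappers, BetaVAE_H, DuckieNetwork and SACAgent.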
def test(args):
    duckie.logger.disabled = True  # Disable log messages from duckietown
    env = DuckietownEnv(
        seed=None,
        map_name="4way_bordered",
        max_steps=500001,
        draw_curve=False,
        draw_bbox=False,
        domain_rand=False,
        randomize_maps_on_reset=False,
        accept_start_angle_deg=4,
        full_transparency=True,
        user_tile_start=None,
        num_tris_distractors=12,
        enable_leds=False,
    )

    # Load the pretrained beta-VAE encoder
    encoder = BetaVAE_H(10, 3)
    loaded_model = torch.load(args.encoder_path)
    encoder.load_state_dict(loaded_model['model_states']['net'])
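    # Observation pipeline: resize frames to 64x64, swap to channel-first
    # layout, normalize pixels, convert to torch tensors, and finally map
    # images into the VAE latent space via the encoder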
    env = ResizeWrapper(env, 64, 64)
    env = SwapDimensionsWrapper(env)
    env = ImageNormalizeWrapper(env)
    env = TorchifyWrapper(env)
    env = EncoderWrapper(env, encoder)
    #env = ActionWrapper(env)
    env = GymEnvironment(env)

    state_size = 14
    act_size = env.gym_env.action_space.shape[0]
    action_def = ContinuousDefinition(env.gym_env.action_space.shape,
                                      env.gym_env.action_space.high,
                                      env.gym_env.action_space.low)

    multihead_net = DuckieNetwork(state_size, act_size)
    agent = SACAgent(multihead_net, action_def)
    agent.load_model(args.model_path)

    for i in range(args.n_episodes):
        total_reward = 0
        state = env.reset()
        state = torch.from_numpy(state).float().unsqueeze(0)
        while True:
            action = agent.act(state, evaluation=True)
            print(action)
            next_state, reward, done, _ = env.step(action)
            state = torch.from_numpy(next_state).float().unsqueeze(0)
            total_reward += reward
            env.render()
            if done:
                print("Total reward:{}".format(total_reward))
                break
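For reference, a minimal driver for this test function might look like the sketch below. The attribute names (encoder_path, model_path, n_episodes) come from the snippet itself; the CLI flags and the default episode count are placeholders.

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--encoder_path", required=True)
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--n_episodes", type=int, default=5)
    test(parser.parse_args())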
Code example #2
File: ttt_action.py Project: tpvt99/sbcs5478
    model = PPO.load(osp.join(results_dir, "best_model", "best_model.zip"))

else:
    raise ValueError("Unknown model type")

# Use the saved normalization statistics and
# do not update them at test time
env.training = False
# Reward normalization is not needed at test time
env.norm_reward = False

obs = env.reset()
steps = 0
rewards = 0
done, state = False, None
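# NOTE: `state` holds a recurrent policy's hidden state; it remains None
# for feed-forward policies (e.g. the default MlpPolicy)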
while True:
    env.render()
    # Get action
    action, state = model.predict(obs, state=state, deterministic=False)
    obs, reward, done, info = env.step(action)
    print(
        f'Step {steps} Action {action} with Reward {reward} with info {info}')

    steps += 1
    rewards += reward

    if done:
        break

print(f"Steps: {steps}")
print(f"Total reward: {rewards}")
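The env.training / env.norm_reward flags above belong to stable-baselines3's VecNormalize wrapper. A minimal sketch of restoring such an env at test time follows, assuming the statistics were saved to a hypothetical vecnormalize.pkl next to the model, and that make_env() is a project-specific environment factory.

import os.path as osp

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

# Paths are placeholders; adjust to the project's layout
results_dir = "results"
stats_path = osp.join(results_dir, "vecnormalize.pkl")

# Rebuild the raw env, then restore the saved normalization statistics
venv = DummyVecEnv([lambda: make_env()])  # make_env() is hypothetical
env = VecNormalize.load(stats_path, venv)

# Freeze the statistics for evaluation, as in the loop above
env.training = False
env.norm_reward = False

model = PPO.load(osp.join(results_dir, "best_model", "best_model.zip"), env=env)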