# Wrappers
env = ResizeWrapper(env)
env = NormalizeWrapper(env)
env = ImgWrapper(env)  # reorder image observations from 160x120x3 (HWC) to 3x160x120 (CHW)
env = ActionWrapper(env)
# env = DtRewardWrapper(env)  # reward shaping wrapper; not used during testing
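# For reference, a minimal sketch of what an observation wrapper such as
# ImgWrapper presumably does: transpose HWC images to CHW so a PyTorch CNN can
# consume them (an assumption for illustration; the actual wrapper in this repo
# may resize or normalise differently).
import gym
import numpy as np
from gym import spaces

class ChannelFirstWrapper(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        h, w, c = env.observation_space.shape
        # Assumes observations were already scaled to [0, 1] by NormalizeWrapper.
        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(c, h, w), dtype=np.float32)

    def observation(self, obs):
        # (H, W, C) -> (C, H, W)
        return np.transpose(obs, (2, 0, 1))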

state_dim = env.observation_space.shape
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
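# With the wrappers above, state_dim is typically a channel-first image shape
# such as (3, 160, 120), action_dim is 2 (left/right wheel velocities) and
# max_action is 1.0; the exact values depend on the environment configuration.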

# Initialize the DDPG policy; net_type="cnn" selects the convolutional (image-based) networks
policy = DDPG(state_dim, action_dim, max_action, net_type="cnn")

policy.load(file_name, directory="./pytorch_models")

with torch.no_grad():
    while True:
        obs = env.reset()
        env.render()
        rewards = []
        while True:
            action = policy.predict(np.array(obs))
            obs, rew, done, misc = env.step(action)
            rewards.append(rew)
            env.render()
            if done:
                break
        print("mean episode reward:", np.mean(rewards))

# Collecting logs with the expert policy
obs = env.get_features()
EPISODES, STEPS = 20, 1000
DEBUG = False

# Note: the log file name encodes the total number of logged steps (EPISODES * STEPS) in thousands
logger = Logger(env, log_file=f'train-{int(EPISODES*STEPS/1000)}k.log')
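# For reference, a minimal sketch of what a Logger like this might provide
# (assumption: it buffers one (observation, action, reward, done, info) tuple
# per step and serialises each finished episode to log_file; the repo's Logger
# may record different fields or use another format).
import pickle

class SimpleLogger:
    def __init__(self, env, log_file):
        self.env = env
        self._file = open(log_file, 'wb')
        self._episode = []

    def log(self, observation, action, reward, done, info):
        # Buffer one step of experience.
        self._episode.append((observation, action, reward, done, info))

    def on_episode_done(self):
        # Flush the buffered episode to disk and start a new one.
        pickle.dump(self._episode, self._file)
        self._episode = []

    def close(self):
        self._file.close()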

start_time = time.time()
print(f"[INFO] Starting to collect logs for {EPISODES} episodes "
      f"of {STEPS} steps each...")
with torch.no_grad():
    # let's collect our samples
    for episode in range(EPISODES):
        for steps in range(STEPS):
            # we use our 'expert' to predict the next action.
            action = expert.predict(np.array(obs))
            # Apply the action
            observation, reward, done, info = env.step(action)
            # Get features (the state representation) for the RL agent
            obs = env.get_features()

            if done:
                print(f"#Episode: {episode}\t | #Step: {steps}")
                break

            # If there is no closest curve point, the robot has left the
            # drivable area, so end the episode early.
            closest_point, _ = env.closest_curve_point(env.cur_pos, env.cur_angle)
            if closest_point is None:
                done = True
                break
            # Cut the horizon: obs.shape = (480,640,3) --> (300,640,3)
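            # One plausible way to perform that crop (an assumption for
            # illustration; the actual slice used in the full script may differ):
            # observation = observation[180:, :, :]  # keep the lower 300 rows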