Example No. 1
def make_atari_env_watch(args):
    return wrap_deepmind(
        args.task,
        frame_stack=args.frames_stack,
        episode_life=False,
        clip_rewards=False
    )
Example No. 2
def make_atari_env_watch(args):
    environment = wrap_deepmind(args.task,
                                frame_stack=args.frames_stack,
                                episode_life=False,
                                clip_rewards=False)
    if args.invert_reward:
        environment = InverseReward(environment)
    return environment
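The InverseReward wrapper used here (and again in Example No. 12) is not shown; a minimal sketch, assuming it simply negates every reward, could look like this:

import gym


class InverseReward(gym.RewardWrapper):
    # Hypothetical implementation: flip the sign of every reward.
    def reward(self, reward):
        return -reward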
Example No. 3
def evaluate(policy_model,
             action_space,
             episode_true,
             epsilon=0.01,
             num_episode=10):
    """ Evaluate the performance of the agent

        policy_model: the agent model
        action_space: action space dimension of game
        episode_true: the training episodes
        epsilon : the probability of agent choose random action
        num_episode : the test episodes of evaluation process
    """
    env = atari_wrapper.make_atari('RiverraidNoFrameskip-v4')
    env = atari_wrapper.wrap_deepmind(env,
                                      clip_rewards=False,
                                      frame_stack=True,
                                      pytorch_img=True)
    test_scores = []
    score = 0
    episode = 0
    while episode < num_episode:

        observation = env.reset()
        done = False

        while not done:
            t_observation = trace.from_numpy(observation).float() / 255
            # After the view, t_observation.shape == torch.Size([1, 4, 84, 84])
            t_observation = t_observation.view(
                1, t_observation.shape[0], t_observation.shape[1],
                t_observation.shape[2])
            if random.random() > epsilon:  # choose action by epsilon-greedy
                q_value = policy_model(t_observation)
                action = q_value.argmax(1).data.cpu().numpy().astype(int)[0]
            else:
                action = random.sample(range(len(action_space)), 1)[0]

            next_observation, reward, done, info = env.step(
                action_space[action])
            observation = next_observation
            score += reward

        if info['ale.lives'] == 0:
            test_scores.append(score)
            episode += 1
            score = 0

    f = open("file.txt", 'a')
    f.write("%f, %d, %d\n" % (float(sum(test_scores)) / float(num_episode),
                              episode_true, num_episode))
    f.close()

    mean_reward = float(sum(test_scores)) / float(num_episode)

    return mean_reward
Example No. 4
        def sub_env_creator():
            sub_env = make_atari(env_name)
            sub_env.seed(seed + env_num)

            if env_num == 0 and num_envs > 1:
                # Wrap first env in default monitor for video output
                # Results will be transformed into baselines monitor style at the end of the run
                sub_env = gym.wrappers.Monitor(sub_env, results_save_dir)
            else:
                # Wrap every other env in the baselines monitor for equivalent plotting.
                sub_env = Monitor(sub_env, join(results_save_dir, str(env_num)))

            sub_env = wrap_deepmind(sub_env, frame_stack=True, scale=True)

            return sub_env
Example No. 5
def thunk():
    # env = gym.make(gym_id)
    env = wrap_atari(gym_id)
    env = gym.wrappers.RecordEpisodeStatistics(env)
    if args.capture_video and idx == 0:
        env = Monitor(env, f'videos/{experiment_name}')
    env = wrap_pytorch(
        wrap_deepmind(
            env,
            clip_rewards=True,
            frame_stack=True,
            scale=False,
        ))
    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    return env
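Factory functions like thunk above are usually passed, uncalled, to a vectorized environment so that each worker builds its own copy. A minimal usage sketch assuming gym's SyncVectorEnv; the make_env helper below is hypothetical and stands in for whatever outer function binds gym_id and seed in the original code:

import gym


def make_env(gym_id, seed):
    # Hypothetical outer factory; in the original code the thunk also applies
    # the Atari wrappers shown above. A bare gym.make keeps this sketch runnable.
    def thunk():
        env = gym.make(gym_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env
    return thunk


envs = gym.vector.SyncVectorEnv(
    [make_env('BreakoutNoFrameskip-v4', seed=1 + i) for i in range(4)])
obs = envs.reset()  # one stacked observation per sub-environment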
Example No. 6
def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)
Example No. 7
def main(argv):
    env = make_atari(FLAGS.env)
    env = wrap_deepmind(env, frame_stack=True)

    if FLAGS.agent == 'Rainbow':
        FLAGS.network = 'Dueling_Net'
        FLAGS.multi_step = 3
        FLAGS.category = True
        FLAGS.noise = True

    message = OrderedDict({
        "Env": env,
        "Agent": FLAGS.agent,
        "Network": FLAGS.network,
        "Episode": FLAGS.n_episode,
        "Max_Step": FLAGS.step,
        "Categorical": FLAGS.category,
        "init_model": FLAGS.model
    })

    out_dim = set_output_dim(FLAGS, env.action_space.n)

    agent = eval(FLAGS.agent)(model=set_model(outdim=out_dim),
                              n_actions=env.action_space.n,
                              n_features=env.observation_space.shape,
                              learning_rate=0,
                              e_greedy=0,
                              reward_decay=0,
                              replace_target_iter=0,
                              e_greedy_increment=0,
                              optimizer=None,
                              network=FLAGS.network,
                              trainable=False,
                              is_categorical=FLAGS.category,
                              is_noise=FLAGS.noise,
                              gpu=find_gpu())

    if FLAGS.agent == 'PolicyGradient':
        trainer = PolicyTrainer(agent=agent,
                                env=env,
                                n_episode=FLAGS.n_episode,
                                max_step=FLAGS.step,
                                replay_size=0,
                                data_size=0,
                                n_warmup=0,
                                priority=None,
                                multi_step=0,
                                render=FLAGS.render,
                                test_episode=5,
                                test_interval=0,
                                test_frame=FLAGS.rec,
                                test_render=FLAGS.test_render,
                                metrics=message,
                                init_model_dir=FLAGS.model)

    elif FLAGS.agent == 'A3C' or FLAGS.agent == 'Ape_X':
        trainer = DistributedTrainer(agent=agent,
                                     n_workers=0,
                                     env=env,
                                     n_episode=FLAGS.n_episode,
                                     max_step=FLAGS.step,
                                     replay_size=0,
                                     data_size=0,
                                     n_warmup=0,
                                     priority=None,
                                     multi_step=0,
                                     render=False,
                                     test_episode=5,
                                     test_interval=0,
                                     test_frame=FLAGS.rec,
                                     test_render=FLAGS.test_render,
                                     metrics=message,
                                     init_model_dir=FLAGS.model)

    else:
        trainer = Trainer(agent=agent,
                          env=env,
                          n_episode=FLAGS.n_episode,
                          max_step=FLAGS.step,
                          replay_size=0,
                          data_size=0,
                          n_warmup=0,
                          priority=None,
                          multi_step=0,
                          render=FLAGS.render,
                          test_episode=5,
                          test_interval=0,
                          test_frame=FLAGS.rec,
                          test_render=FLAGS.test_render,
                          metrics=message,
                          init_model_dir=FLAGS.model)

    trainer.test()

    return
Example No. 8
def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("use_cuda: ", use_cuda)
    print("Device: ", device)

    env = atari_wrapper.make_atari('RiverraidNoFrameskip-v4')
    env = atari_wrapper.wrap_deepmind(env,
                                      clip_rewards=False,
                                      frame_stack=True,
                                      pytorch_img=True)

    action_space = [a for a in range(env.action_space.n)]
    n_action = len(action_space)

    # DQN Model and optimizer:
    policy_model = DQNModel().to(device)
    target_model = DQNModel().to(device)
    target_model.load_state_dict(policy_model.state_dict())

    optimizer = torch.optim.RMSprop(policy_model.parameters(),
                                    lr=lr,
                                    alpha=alpha)

    # Initialize the Replay Buffer
    replay_buffer = ReplayBuffer(rep_buf_size)

    while len(replay_buffer) < rep_buf_ini:

        observation = env.reset()
        done = False

        while not done:
            with torch.no_grad():
                t_observation = torch.from_numpy(observation).float().to(
                    device)
                t_observation = t_observation.view(1, t_observation.shape[0],
                                                   t_observation.shape[1],
                                                   t_observation.shape[2])
                action = random.sample(range(len(action_space)), 1)[0]

            next_observation, reward, done, info = env.step(
                action_space[action])

            replay_buffer.push(observation, action, reward, next_observation,
                               done)
            observation = next_observation

    print('Experience Replay buffer initialized')

    # Use log to record the performance
    logger = logging.getLogger('dqn_Riverraid')
    logger.setLevel(logging.INFO)
    logger_handler = logging.FileHandler('./dqn_Riverraid.log')
    logger.addHandler(logger_handler)

    # Training part
    env.reset()
    score = 0
    episode_score = []
    mean_episode_score = []
    episode_true = 0
    num_frames = 0
    episode = 0
    last_100episode_score = deque(maxlen=100)

    while episode < max_episodes:

        observation = env.reset()
        done = False
        # import time
        # start=time.time()

        while not done:

            with torch.no_grad():

                t_observation = torch.from_numpy(observation).float().to(
                    device) / 255
                t_observation = t_observation.view(1, t_observation.shape[0],
                                                   t_observation.shape[1],
                                                   t_observation.shape[2])
                epsilon = epsilon_by_frame(num_frames)
                if random.random() > epsilon:
                    q_value = policy_model(t_observation)
                    action = q_value.argmax(1).data.cpu().numpy().astype(
                        int)[0]
                else:
                    action = random.sample(range(len(action_space)), 1)[0]

            next_observation, reward, done, info = env.step(
                action_space[action])
            num_frames += 1
            score += reward

            replay_buffer.push(observation, action, reward, next_observation,
                               done)
            observation = next_observation

            # Update policy
            if (len(replay_buffer) > batch_size
                    and num_frames % skip_frame == 0):
                observations, actions, rewards, next_observations, dones = replay_buffer.sample(
                    batch_size)

                observations = torch.from_numpy(np.array(observations) /
                                                255).float().to(device)

                actions = torch.from_numpy(
                    np.array(actions).astype(int)).float().to(device)
                actions = actions.view(actions.shape[0], 1)

                rewards = torch.from_numpy(
                    np.array(rewards)).float().to(device)
                rewards = rewards.view(rewards.shape[0], 1)

                next_observations = torch.from_numpy(
                    np.array(next_observations) / 255).float().to(device)

                dones = torch.from_numpy(
                    np.array(dones).astype(int)).float().to(device)
                dones = dones.view(dones.shape[0], 1)

                q_values = policy_model(observations)
                next_q_values = target_model(next_observations)

                q_value = q_values.gather(1, actions.long())
                next_q_value = next_q_values.max(1)[0].unsqueeze(1)
                expected_q_value = rewards + gamma * next_q_value * (1 - dones)

                loss = huber_loss(q_value, expected_q_value)

                optimizer.zero_grad()
                loss.backward()

                optimizer.step()

                for target_param, policy_param in zip(
                        target_model.parameters(), policy_model.parameters()):
                    target_param.data.copy_(TAU * policy_param.data +
                                            (1 - TAU) * target_param.data)

        episode += 1
        # episode_score.append(score)
        # end=time.time()
        # print("Running time ( %i episode): %.3f Seconds "%(episode ,end-start))

        if info['ale.lives'] == 0:
            # episode_score.append(score)
            mean_score = score
            episode_true += 1
            score = 0

            # if episode % 20 == 0:
            # mean_score = np.mean(episode_score)
            mean_episode_score.append(mean_score)
            last_100episode_score.append(mean_score)
            # episode_score = []
            logger.info('Frame: ' + str(num_frames) + ' / Episode: ' +
                        str(episode_true) + ' / Average Score : ' +
                        str(int(mean_score)) + '   / epsilon: ' +
                        str(float(epsilon)))
            #plot_score(mean_episode_score, episode_true)
            pickle.dump(mean_episode_score,
                        open('./dqn_Riverraid_mean_scores.pickle', 'wb'))
            if episode_true % 50 == 1:
                logger.info('Frame: ' + str(num_frames) + ' / Episode: ' +
                            str(episode_true) + ' / Average Score : ' +
                            str(int(mean_score)) + '   / epsilon: ' +
                            str(float(epsilon)) +
                            '   / last_100episode_score: ' +
                            str(float(np.mean(last_100episode_score))))

        if episode % 50 == 0:
            torch.save(target_model.state_dict(),
                       './dqn_spaceinvaders_target_model_state_dict.pt')
            torch.save(policy_model.state_dict(),
                       './dqn_spaceinvaders_model_state_dict.pt')

    pass
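Example No. 8 relies on helpers defined elsewhere in the original file (huber_loss, epsilon_by_frame) and on module-level hyperparameters (lr, alpha, TAU, gamma, batch_size, skip_frame, ...). A minimal sketch of the two functions, assuming the standard smooth-L1 loss and an exponential exploration schedule with made-up constants:

import math

import torch.nn.functional as F

# Assumed schedule constants; the original values are not shown.
epsilon_start, epsilon_final, epsilon_decay = 1.0, 0.02, 30000


def huber_loss(q_value, expected_q_value):
    # Smooth-L1 (Huber) loss on the TD error, the usual choice for DQN.
    return F.smooth_l1_loss(q_value, expected_q_value)


def epsilon_by_frame(frame_idx):
    # Anneal epsilon exponentially from epsilon_start towards epsilon_final.
    return epsilon_final + (epsilon_start - epsilon_final) * math.exp(
        -frame_idx / epsilon_decay)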
Example No. 9
def main():
    # Initialize environment and :
    env = atari_wrapper.make_atari('RiverraidNoFrameskip-v4')
    env = atari_wrapper.wrap_deepmind(env,
                                      clip_rewards=True,
                                      frame_stack=True,
                                      pytorch_img=True)
    action_space = [a for a in range(env.action_space.n)]

    # Initialize DQN Model and optimizer:
    policy_model = DQN_craft()
    target_model = DQN_craft()
    print(policy_model)
    target_model.eval()
    target_model.load_seq_list(policy_model.seq_list())
    optimizer = optim.RMSprop(policy_model.parameters(), lr=lr, alpha=alpha)

    # -------------------------------------------------
    # Initialize the Replay Buffer
    replay_buffer = ReplayBuffer_Init(rep_buf_size, rep_buf_ini, env,
                                      action_space)

    # Use log to record the performance
    logger = logging.getLogger('dqn_Riverraid')
    logger.setLevel(logging.INFO)
    logger_handler = logging.FileHandler('./dqn_Riverraid.log')
    logger.addHandler(logger_handler)

    # --------------------------------------------------------------------------------------------------------------
    # Training part, Initialization below
    env.reset()
    score = 0
    episode_scores = []  # A list recording the score of every episode_true
    episode_true = 0  # We regard end of life as end of episode; since there are
    # 4 lives in RiverRaid, one episode_true = 4 episodes
    num_frames = 0
    episode = 0
    average_100_episode = []  # For plot
    max_100_episode = []  # For plot
    frame_1000 = []  # For plot
    last_25episode_score = deque(maxlen=25)  # for plot
    loss_list = []  # for plot
    loss_running = 0  # for log
    # End of initialization
    # --------------------------------------------------------------------------------------------------------------

    while episode < max_episodes:

        observation = env.reset()
        done = False

        while not done:

            t_observation = trace.from_numpy(observation).float() / 255
            # t_observation = t_observation.view(1, t_observation.shape[0],
            #                                     t_observation.shape[1],
            #                                     t_observation.shape[
            #                                         2])  # t_observation.shape:torch.Size([1, 4, 84, 84])
            t_observation = t_observation.unsqueeze(0)
            epsilon = epsilon_by_frame(num_frames)
            if random.random() > epsilon:  # choose action by epsilon-greedy
                q_value = policy_model(t_observation)
                action = q_value.argmax(1).data.numpy().astype(int)[0]
            else:
                action = random.sample(range(len(action_space)), 1)[0]
            # Store experience in the replay buffer
            next_observation, reward, done, info = env.step(
                action_space[action])
            replay_buffer.push(observation, action, reward, next_observation,
                               done)
            observation = next_observation
            num_frames += 1  # update frame
            score += reward

            # Update policy
            if (len(replay_buffer) > batch_size
                    and num_frames % skip_frame == 0):
                observations, actions, rewards, next_observations, dones = replay_buffer.sample(
                    batch_size)

                observations = trace.from_numpy(np.array(observations) /
                                                255).float()

                actions = trace.from_numpy(
                    np.array(actions).astype(int)).float()
                actions = actions.view(actions.shape[0],
                                       1)  # torch.Size([32, 1])

                rewards = trace.from_numpy(np.array(rewards)).float()
                rewards = rewards.view(rewards.shape[0],
                                       1)  # torch.Size([32, 1])

                next_observations = trace.from_numpy(
                    np.array(next_observations) / 255).float()

                dones = trace.from_numpy(np.array(dones).astype(int)).float()
                dones = dones.view(dones.shape[0], 1)  # torch.Size([32, 1])

                q_values = policy_model(observations)  # torch.Size([32, 18])
                next_q_values = target_model(
                    next_observations)  # torch.Size([32, 18])

                q_value = q_values.Gather(
                    actions.squeeze().long())  # torch.Size([32, 1])
                next_q_value = next_q_values.max(1)[0].unsqueeze(
                    1)  # torch.Size([32, 1])
                expected_q_value = rewards + gamma * next_q_value * (1 - dones)
                """
                if Double_dqn:  # Whether use double dqn
                    selected_action = policy_model(next_observations).argmax(dim=1, keepdim=True)
                    next_q_value = next_q_values.gather(1, selected_action)
                else:
                    next_q_value = next_q_values.max(1)[0].unsqueeze(1)  # torch.Size([32, 1])
                """
                mse = nn.MSE()
                loss = mse(q_value, expected_q_value)
                # loss = huber_loss(q_value, expected_q_value)
                loss_running = loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print(f"\rloss: {loss}", end='')
                # Soft update or Hard update (common dqn)
                if Soft_update:
                    for target_param, policy_param in zip(
                            target_model.parameters(),
                            policy_model.parameters()):
                        target_param.data.copy_(TAU * policy_param.data +
                                                (1 - TAU) * target_param.data)
                else:
                    if num_frames % target_update == 0:
                        # target_model.load_state_dict(policy_model.state_dict())
                        target_model.load_seq_list(policy_model.seq_list())

        episode += 1
        # End of episode_true, reset the score
        if info['ale.lives'] == 0:
            episode_scores.append(score)
            last_25episode_score.append(score)

            # record in the log
            logger.info('Frame: ' + str(num_frames) + '  / Episode: ' +
                        str(episode_true) + '  / Average Score : ' +
                        str(int(score)) + '  / epsilon: ' +
                        str(float(epsilon)) + '  / loss : ' +
                        str(float(loss_running)))
            pickle.dump(episode_scores,
                        open('./dqn_Riverraid_mean_scores.pickle', 'wb'))

            episode_true += 1
            score = 0

            print(
                '\r Episode_true {} \t Average Score (last 25 episodes) {:.2f} '
                .format(episode_true, np.mean(last_25episode_score)),
                end=" ")
            # Update the log
            if episode_true % 25 == 1:
                logger.info('Frame: ' + str(num_frames) + '  / Episode: ' +
                            str(episode_true) + '  / Average Score : ' +
                            '         ' + '  / epsilon: ' +
                            str(float(epsilon)) +
                            '  / last_25episode_score: ' +
                            str(float(np.mean(last_25episode_score))))

                print("episode_ture: ", episode_true, "average_100_episode :",
                      np.mean(last_25episode_score))
                print("episode:", episode)
            # This "if " is for plot (to store the data of per iteration )
            if episode_true % 25 == 0:
                average_100_episode.append(np.mean(last_25episode_score))
                max_100_episode.append(np.max(last_25episode_score))
                loss_list.append(loss_running)
                frame_1000.append(num_frames / 1000.)
            # plot the scores and loss and save the picture
            if episode_true % 500 == 0:
                plt_result(average_100_episode, max_100_episode, frame_1000)
                plt_loss(loss_list, frame_1000)
            # Evaluation Part
            if Evaluation:
                if episode_true % evaluate_frequency == 0:
                    test_score = evaluate(policy_model,
                                          action_space,
                                          episode_true,
                                          epsilon=evaluate_epsilon,
                                          num_episode=evaluate_episodes)
                    print("test_score : ", test_score, "  ", "test episodes: ",
                          evaluate_episodes)

                    if test_score > test_stander:  # Save the model if the test score > test_stander
                        trace.save(
                            './dqn_RiverRaid_policy_model_state_dict.pth',
                            policy_model.seq_list())
                        print("Test score > %d , stop train" % test_stander)
                        break
        # Save the model
        if episode % save_frequency == 0:
            trace.save('./dqn_RiverRaid_policy_model_state_dict.pth',
                       policy_model.seq_list())

    plt_result(average_100_episode, max_100_episode, frame_1000)
    plt_loss(loss_list, frame_1000)
    pass
Example No. 10
from DQL import DQL_agent
import timeit
import atari_wrapper

env_to_use = 'Breakout-ram-v4'
# game parameters
env = atari_wrapper.make_atari(env_to_use)
env = atari_wrapper.wrap_deepmind(env,
                                  episode_life=True,
                                  clip_rewards=False,
                                  frame_stack=False,
                                  scale=True)

state_space = 128  # Using the ram input
action_space = 4
'''
env.step() returns a tuple (state, reward, done, info).

Action meanings for Time Pilot:
action=1 -> go straight
action=2 -> go up, no fire
action=3 -> go right, no fire
action=4 -> go left, no fire
'''

# We initialize our agent.

agent = DQL_agent(state_space=state_space, action_space=action_space)
reward_list = []
eps_length_list = []
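The snippet stops right after initialization; a minimal interaction-loop sketch showing how reward_list and eps_length_list would typically be filled, with a random action standing in for the agent (whose DQL_agent interface is not shown above):

import random

for episode in range(10):  # assumed episode count
    state = env.reset()
    done, ep_reward, ep_length = False, 0.0, 0
    while not done:
        action = random.randrange(action_space)  # placeholder for the agent's choice
        state, reward, done, info = env.step(action)  # (state, reward, done, info)
        ep_reward += reward
        ep_length += 1
    reward_list.append(ep_reward)
    eps_length_list.append(ep_length)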
Example No. 11
def play(env_id, model_path, max_ep, video):

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if get_env_type(env_id) == 'atari':
        env = make_atari(env_id)
        env = wrap_deepmind(env, False, False, False, False)
        env = wrap_pytorch(env)

        model_type = 'conv'
    else:
        env = gym.make(env_id)

        model_type = 'linear'

    obs_shape = env.observation_space.shape
    num_actions = env.action_space.n

    agent = DQN(obs_shape, num_actions, device=device, model=model_type)
    agent.load(model_path)

    policy = Greedy(agent)

    ep = 1
    episode_reward = 0

    obs = env.reset()
    screen = env.render(mode='rgb_array')
    if video:
        writer = skvideo.io.FFmpegWriter(f'videos/{env_id}-ep-{ep}.mp4')

    for t in count():

        action = policy.act(obs, t)
        next_obs, reward, done, _ = env.step(action)

        episode_reward += reward

        screen = env.render(mode='rgb_array')
        if video:
            writer.writeFrame(screen)

        obs = next_obs

        if done:

            print(f'ep: {ep:4} reward: {episode_reward}')

            if ep >= max_ep:
                break

            ep += 1
            episode_reward = 0
            obs = env.reset()

            if video:
                writer.close()
                writer = skvideo.io.FFmpegWriter(
                    f'videos/{env_id}-ep-{ep}.mp4')

    if video:
        writer.close()
    env.close()
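Greedy (and the EpsilonGreedy used for training in Example No. 14) are not shown; a minimal sketch of both, assuming the agent exposes a q_values(obs) method (a hypothetical name) that returns one Q-value per action:

import math
import random


class Greedy:
    def __init__(self, agent):
        self.agent = agent

    def act(self, obs, t):
        # Always pick the highest-valued action; t is unused here.
        return int(self.agent.q_values(obs).argmax())  # q_values() is assumed


class EpsilonGreedy(Greedy):
    def __init__(self, agent, num_actions, init_epsilon, final_epsilon,
                 epsilon_decay):
        super().__init__(agent)
        self.num_actions = num_actions
        self.init_epsilon = init_epsilon
        self.final_epsilon = final_epsilon
        self.epsilon_decay = epsilon_decay

    def act(self, obs, t):
        # Exploration probability annealed with the timestep t.
        epsilon = self.final_epsilon + (
            self.init_epsilon - self.final_epsilon) * math.exp(
                -t / self.epsilon_decay)
        if random.random() < epsilon:
            return random.randrange(self.num_actions)
        return super().act(obs, t)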
Example No. 12
def make_atari_env(args):
    environment = wrap_deepmind(args.task, frame_stack=args.frames_stack)
    if args.invert_reward:
        environment = InverseReward(environment)
    return environment
Example No. 13
def make_atari_env(args):
    environment = wrap_deepmind(args.task, frame_stack=args.frames_stack)
    return environment
Example No. 14
def train(env_id,
          lr=1e-4,
          gamma=0.99,
          memory_size=1000,
          batch_size=32,
          train_timesteps=10000,
          train_start_time=1000,
          target_update_frequency=1000,
          init_epsilon=1,
          final_epsilon=0.1,
          epsilon_decay=300,
          model_path=None):

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    LOG_PATH = f'logs/dqn_log_{env_id}.txt'

    if get_env_type(env_id) == 'atari':
        env = make_atari(env_id)
        env = wrap_deepmind(env)
        env = wrap_pytorch(env)

        model_type = 'conv'
    else:
        env = gym.make(env_id)

        model_type = 'linear'

    obs_shape = env.observation_space.shape
    num_actions = env.action_space.n

    memory = ReplayBuffer(memory_size)

    agent = DQN(obs_shape, num_actions, lr, gamma, device, model_type)
    policy = EpsilonGreedy(agent, num_actions, init_epsilon, final_epsilon,
                           epsilon_decay)

    # populate replay memory
    obs = env.reset()
    for t in range(train_start_time):

        # uniform random policy
        action = random.randrange(num_actions)
        next_obs, reward, done, _ = env.step(action)
        memory.add(obs, action, reward, next_obs, done)

        obs = next_obs

        if done:
            # start a new episode
            obs = env.reset()

    # for monitoring
    ep_num = 1
    ep_start_time = 1
    episode_reward = 0
    reward_list = []

    # train start
    obs = env.reset()
    for t in tqdm.tqdm(range(1, train_timesteps + 1)):

        # choose action
        action = policy.act(obs, t)
        next_obs, reward, done, _ = env.step(action)
        memory.add(obs, action, reward, next_obs, done)

        obs = next_obs

        # sample batch transitions from memory
        transitions = memory.sample(batch_size)
        # train
        loss = agent.train(transitions)

        # record reward
        episode_reward += reward

        # update target network at every C timesteps
        if t % target_update_frequency == 0:
            agent.update_target()

        if done:
            # start a new episode
            obs = env.reset()

            # write log
            with open(LOG_PATH, 'a') as f:
                f.write(f'{ep_num}\t{episode_reward}\t{ep_start_time}\t{t}\n')

            if model_path is not None:
                # save model
                info = {
                    'epoch': ep_num,
                    'timesteps': t,
                }
                agent.save(model_path, info)

            ep_num += 1
            ep_start_time = t + 1
            reward_list.append(episode_reward)
            episode_reward = 0
        advantage, returns = generalized_advantage(rewards, masks, values)

        # normalize returns
        optimizer.zero_grad()
        loss_p = (-log_probs * advantage).sum(axis=1).mean()
        loss_v = F.mse_loss(values, returns,
                            reduction='none').sum(axis=1).mean()
        loss = loss_p + loss_v
        loss.backward()
        optimizer.step()

        mean_episode_reward = (rewards * masks).sum(axis=1).mean().item()
        running_reward = (1-alpha)*mean_episode_reward \
                    + alpha*running_reward
        if i % 10 == 0:
            print("Episode:{}\t Mean Episode Reward:{} \t\
                        Running Reward:{}".format(i, mean_episode_reward,
                                                  running_reward))

        if np.abs(running_reward) > 20:
            print("Model converged in", time.time() - start_time)
            rewards, _, _, _, _ = \
                        collect_samples(env, policy, 1, render=True)
            break
    return policy


if __name__ == '__main__':
    env = wrap_deepmind(make_atari('PongNoFrameskip-v4'), scale=True)
    vanilla_pg_with_GAE(env)
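The ReplayBuffer used by train() above (and by Examples No. 8 and 9, where the same interface is called through push instead of add) is not shown; a minimal deque-based sketch matching both call sites:

import random
from collections import deque


class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, obs, action, reward, next_obs, done):
        self.buffer.append((obs, action, reward, next_obs, done))

    push = add  # alias; some of the examples above call push()

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        # Transpose into (observations, actions, rewards, next_observations, dones).
        return zip(*batch)

    def __len__(self):
        return len(self.buffer)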