Example #1
File: test.py Project: ChrisFugl/DoomRL
def test(config, env):
    ob_space = env.observation_space
    ac_space = env.action_space
    tf.reset_default_graph()
    gpu_opts = tf.GPUOptions(allow_growth=True)
    tf_config = tf.ConfigProto(
        inter_op_parallelism_threads=1,
        intra_op_parallelism_threads=1,
        gpu_options=gpu_opts,
    )
    with tf.Session(config=tf_config) as sess:
        config.batch_size = 2
        config.number_of_steps = 2
        policy = build_policy(env, 'cnn')
        model = Model(policy=policy,
                      env=env,
                      nsteps=config.number_of_steps,
                      ent_coef=config.entropy_weight,
                      vf_coef=config.critic_weight,
                      max_grad_norm=config.max_grad_norm,
                      lr=config.learning_rate,
                      alpha=config.rmsp_decay,
                      epsilon=config.discount_factor,
                      total_timesteps=config.timesteps,
                      lrschedule='linear')
        model.load(config.load_path)
        return make_rollouts(config, env, model)
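A minimal invocation sketch for this test function. The config below is a hypothetical stand-in whose field names are simply the attributes read above, and the Gym env id is a placeholder for whatever Doom environment the project actually builds.

from types import SimpleNamespace
import gym

config = SimpleNamespace(
    batch_size=2,
    number_of_steps=2,
    entropy_weight=0.01,
    critic_weight=0.5,
    max_grad_norm=0.5,
    learning_rate=7e-4,
    rmsp_decay=0.99,
    discount_factor=0.99,  # note: test() above forwards this value as Model's epsilon argument
    timesteps=int(80e6),
    load_path='checkpoints/model.ckpt',  # placeholder checkpoint path
)
env = gym.make('PongNoFrameskip-v4')  # placeholder; the project wraps its own Doom env
rollouts = test(config, env)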
Example #2
def enjoy(env_id, seed, policy, model_filename, fps=100):
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    else:
        raise ValueError('unknown policy: {}'.format(policy))

    env = wrap_deepmind(make_atari(env_id), clip_rewards=False, frame_stack=True)
    env.seed(seed)

    tf.reset_default_graph()
    ob_space = env.observation_space
    ac_space = env.action_space
    nsteps = 5  # default value, change if needed

    model = Model(policy=policy_fn, ob_space=ob_space, ac_space=ac_space, nenvs=1, nsteps=nsteps)
    model.load(model_filename)

    while True:
        obs, done = env.reset(), False
        episode_rew = 0
        while not done:
            env.render()
            time.sleep(1.0 / fps)
            action, _, _, _ = model.step_model.step([obs.__array__()])
            obs, rew, done, _ = env.step(action)
            episode_rew += rew
        print('Episode reward:', episode_rew)

    env.close()
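A hedged usage sketch for enjoy(); the Atari env id and checkpoint filename below are placeholders rather than values from the project.

enjoy('BreakoutNoFrameskip-v4', seed=0, policy='cnn',
      model_filename='checkpoints/a2c_breakout.pkl', fps=60)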
Example #3
def enjoy(env_path, seed, max_steps, base_port, model_path):
    env = _wrap_unity_env(env_path, seed, base_port, rank=1)
    model = Model(policy=CnnPolicy,
                  ob_space=env.observation_space,
                  ac_space=env.action_space,
                  nenvs=1,
                  nsteps=5)
    model.load(
        model_path
    )  # This will cause an unknown error when loading a model trained with 'learn'

    step_count = 0
    while step_count <= max_steps:
        obs, done = env.reset(), False
        episode_rew = 0
        step_count += 1

        while not done:
            if keyboard.is_pressed('n'):
                break

            action, _, _, _ = model.step_model.step([obs.__array__()])
            obs, rew, done, _ = env.step(action)
            episode_rew += rew

        print('Episode reward: ', episode_rew)
        print('Step Count: ', step_count)

    env.close()
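For reference, a hypothetical call; the Unity build path, port, and checkpoint path are placeholders.

enjoy(env_path='builds/MyUnityEnv', seed=0, max_steps=10,
      base_port=5005, model_path='checkpoints/a2c_unity.pkl')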
Example #4
def main():
    env = gym.make("gridworld-v0")
    policy = CnnPolicy
    nsteps = 5
    total_timesteps = int(80e6)
    vf_coef = 0.5
    ent_coef = 0.01
    max_grad_norm = 0.5
    lr = 7e-4
    lrschedule = 'linear'
    epsilon = 1e-5
    alpha = 0.99
    gamma = 0.99
    log_interval = 100
    ob_space = env.observation_space
    ac_space = env.action_space
    nenvs = env.num_envs
    #with tf.Session() as sess:
    with tf.Graph().as_default(), tf.Session().as_default():
        #model = a2c.learn(policy=CnnPolicy, env=env,  total_timesteps=int(0), seed=0)
        model = Model(policy=policy,
                      ob_space=ob_space,
                      ac_space=ac_space,
                      nenvs=nenvs,
                      nsteps=nsteps,
                      ent_coef=ent_coef,
                      vf_coef=vf_coef,
                      max_grad_norm=max_grad_norm,
                      lr=lr,
                      alpha=alpha,
                      epsilon=epsilon,
                      total_timesteps=total_timesteps,
                      lrschedule=lrschedule)
        model.load(
            "/Users/constantinos/Documents/Projects/cmu_gridworld/cmu_gym/a2c_open.pkl"
        )

        while True:
            obs, done = env.reset(), False
            episode_rew = 0
            while not done:
                env.render()
                obs, rew, done, _ = env.step(model(obs[None])[0])
                episode_rew += rew
            print("Episode reward", episode_rew)
Example #5
def run(policy,
        env,
        seed,
        nsteps=5,
        nstack=4,
        total_timesteps=int(80e6),
        vf_coef=0.5,
        ent_coef=0.01,
        max_grad_norm=0.5,
        lr=7e-4,
        lrschedule='linear',
        epsilon=1e-5,
        alpha=0.99,
        gamma=0.99,
        log_interval=100):
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    num_procs = len(env.remotes)
    model = Model(policy=policy,
                  ob_space=ob_space,
                  ac_space=ac_space,
                  nenvs=nenvs,
                  nsteps=nsteps,
                  nstack=nstack,
                  num_procs=num_procs,
                  ent_coef=ent_coef,
                  vf_coef=vf_coef,
                  max_grad_norm=max_grad_norm,
                  lr=lr,
                  alpha=alpha,
                  epsilon=epsilon,
                  total_timesteps=total_timesteps,
                  lrschedule=lrschedule)
    model.load('./model/a2c/model.h5')
    runner = Runner(env, model, nsteps=nsteps, nstack=nstack, gamma=gamma)
    while True:
        runner.run()
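run() reads env.num_envs and env.remotes, which matches baselines' SubprocVecEnv. A minimal invocation sketch under that assumption; the env id is a placeholder and CnnPolicy is assumed to be imported as in the other examples.

import gym
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv

def make_env(rank):
    def _thunk():
        env = gym.make('PongNoFrameskip-v4')  # placeholder env id
        env.seed(rank)
        return env
    return _thunk

venv = SubprocVecEnv([make_env(i) for i in range(4)])
run(CnnPolicy, venv, seed=0, nsteps=5, nstack=4)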
Example #6
class ShootEnv(Env):
    def __init__(self):
        self.game = DoomGame()
        self.game.load_config('O:\\Doom\\scenarios\\cig_flat2.cfg')
        self.game.add_game_args(
            "-host 1 -deathmatch +timelimit 1.0 "
            "+sv_forcerespawn 1 +sv_noautoaim 1 +sv_respawnprotect 1 +sv_spawnfarthest 1 +sv_nocrouch 1 "
            "+viz_respawn_delay 0")

        self.game.set_mode(Mode.PLAYER)
        self.game.set_labels_buffer_enabled(True)
        self.game.set_depth_buffer_enabled(True)
        self.game.set_screen_resolution(ScreenResolution.RES_320X240)
        self.game.add_available_game_variable(GameVariable.FRAGCOUNT)

        #define navigation env
        class NavigatorSubEnv(Env):
            def __init__(self, game):
                self.action_space = Discrete(3)
                self.observation_space = Box(low=0,
                                             high=255,
                                             shape=(84, 84, 3),
                                             dtype=np.uint8)
                self._game = game

            def step(self, action):
                # -1 means the action does not actually control the game
                if action > -1:
                    one_hot_action = [[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0],
                                      [0, 0, 0, 0, 1, 0]]
                    self._game.make_action(one_hot_action[action], 4)
                    if self._game.is_episode_finished():
                        self._game.new_episode()
                    if self._game.is_player_dead():
                        self._game.respawn_player()

                obs = get_observation(self._game.get_state())
                return get_observation(
                    self._game.get_state(),
                    real_frame=True), 0, check_enemy_enter(obs), None

            def seed(self, seed=None):
                pass

            def reset(self):
                return get_observation(self._game.get_state(), real_frame=True)

            def render(self, mode='human'):
                pass

        self.navigator = VecFrameStack(
            VecEnvAdapter([NavigatorSubEnv(self.game)]), 4)

        #define navigation network
        self.navigation_policy = Model(CnnPolicy,
                                       self.navigator.observation_space,
                                       self.navigator.action_space,
                                       nenvs=1,
                                       nsteps=1)
        self.navigation_policy.load(
            'O:\\Doom\\baselinemodel\\navigate_real2.dat')

        self.action_space = Discrete(3)  #turn L, turn R, fire
        self.observation_space = Box(low=0,
                                     high=255,
                                     shape=(84, 84, 3),
                                     dtype=np.uint8)
        self.available_actions = [[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0],
                                  [0, 0, 0, 0, 0, 1]]

    def seed(self, seed=None):
        self.game.set_seed(seed)
        self.game.init()
        self.game.send_game_command("removebots")
        for i in range(8):
            self.game.send_game_command("addbot")

    def reset(self):
        obs_for_navigator = self.navigator.reset()
        while True:
            actions, _, _, _ = self.navigation_policy.step(obs_for_navigator)
            obs_for_navigator, _, navi_done, _ = self.navigator.step(actions)
            if navi_done:
                break
        obs = get_observation(self.game.get_state())
        assert check_enemy_enter(obs)
        return get_observation(self.game.get_state(), real_frame=True)

    def step(self, action):
        old_fragcount = self.game.get_game_variable(GameVariable.FRAGCOUNT)
        self.game.make_action(self.available_actions[action], 4)
        new_fragcount = self.game.get_game_variable(GameVariable.FRAGCOUNT)
        rew = new_fragcount - old_fragcount
        done = False

        if self.game.is_episode_finished():
            done = True
            self.game.new_episode()
            self.game.send_game_command("removebots")
            for i in range(8):
                self.game.send_game_command("addbot")

        if self.game.is_player_dead():
            self.game.respawn_player()
            done = True

        if action == 2:  # fire
            rew -= 0.05

        state = self.game.get_state()
        obs = get_observation(state)

        if check_enemy_enter(obs):
            rew += 0.01

        if check_enemy_leave(obs):
            done = True

        return get_observation(state, real_frame=True), rew, done, None
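A short driving loop for ShootEnv, sketched only from the methods defined above; the random action choice is for illustration.

env = ShootEnv()
env.seed(123)  # seed() also initializes the game and adds bots
obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # turn left / turn right / fire
    obs, rew, done, _ = env.step(action)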
Example #7
def test(env_id, policy_name, seed, nstack=1, numAgents=2):
    iters = 100
    rwd = []
    percent_exp = []
    env = gym.make(env_id)
    env.seed(seed)
    print("logger dir: ", logger.get_dir())
    env = bench.Monitor(env,
                        logger.get_dir() and os.path.join(logger.get_dir()))
    continuous_actions = False  # defined here so the branch below does not read it before assignment
    if env_id == 'Pendulum-v0':
        if continuous_actions:
            env.action_space.n = env.action_space.shape[0]
        else:
            env.action_space.n = 10
    gym.logger.setLevel(logging.WARN)
    # img_shape = (84, 84, 3)
    img_shape = (84, 84, 3)
    ob_space = spaces.Box(low=0, high=255, shape=img_shape)
    ac_space = env.action_space

    # def get_img(env):
    #     ax, img = env.get_img()
    #    return ax, img

    # def process_img(img):
    #     img = rgb2grey(copy.deepcopy(img))
    #    img = resize(img, img_shape)
    #    return img

    policy_fn = policy_fn_name(policy_name)

    nsteps = 5
    total_timesteps = int(80e6)
    vf_coef = 0.5
    ent_coef = 0.01
    max_grad_norm = 0.5
    lr = 7e-4
    lrschedule = 'linear'
    epsilon = 1e-5
    alpha = 0.99
    debug = False
    if numAgents == 1:
        model = Model(policy=policy_fn,
                      ob_space=ob_space,
                      ac_space=ac_space,
                      nenvs=1,
                      nsteps=nsteps,
                      nstack=nstack,
                      num_procs=1,
                      ent_coef=ent_coef,
                      vf_coef=vf_coef,
                      max_grad_norm=max_grad_norm,
                      lr=lr,
                      alpha=alpha,
                      epsilon=epsilon,
                      total_timesteps=total_timesteps,
                      lrschedule=lrschedule,
                      continuous_actions=continuous_actions,
                      debug=debug)
        m_name = 'test_model_Mar7_1mil.pkl'
        model.load(m_name)
    else:
        model = []
        for i in range(numAgents):
            model.append(
                Model(policy=policy_fn,
                      ob_space=ob_space,
                      ac_space=ac_space,
                      nenvs=1,
                      nsteps=nsteps,
                      nstack=nstack,
                      num_procs=1,
                      ent_coef=ent_coef,
                      vf_coef=vf_coef,
                      max_grad_norm=max_grad_norm,
                      lr=lr,
                      alpha=alpha,
                      epsilon=epsilon,
                      total_timesteps=total_timesteps,
                      lrschedule=lrschedule,
                      continuous_actions=continuous_actions,
                      debug=debug,
                      itr=i))
        for i in range(numAgents):
            m_name = 'test_model_' + str(i) + '_300k.pkl'  # + '100k.pkl'
            model[i].load(m_name)
            print('---------------------------------------------')
            print("Successfully Loaded: ", m_name)
            print('---------------------------------------------')

    env.env, img = env.reset()
    rwd = [[], []]
    percent_exp = [[], []]
    for i in range(1, iters + 1):
        if i % 10 == 0:
            for j in range(numAgents):
                print('-----------------------------------')
                print('Agent ' + str(j))
                print('Iteration: ', i)
                avg_rwd = sum(rwd[j]) / i
                avg_pct_exp = sum(percent_exp[j]) / i
                med_pct_exp = statistics.median(percent_exp[j])
                print('Average Reward: ', avg_rwd)
                print('Average Percent Explored: ', avg_pct_exp, '%')
                print('Median Percent Explored: ', med_pct_exp)
                print('-----------------------------------')
        # ax, img = get_img(env)
        img_hist = []
        frames_dir = []
        for j in range(numAgents):
            frames_dir.append('exp_frames' + str(j * 100 + i + 200))
            if os.path.exists(frames_dir[j]):
                # raise ValueError('Frames directory already exists.')
                shutil.rmtree(frames_dir[j])
            os.makedirs(frames_dir[j])
            img_hist.append(deque([img[j] for _ in range(4)], maxlen=nstack))
        action = 0
        total_rewards = [0, 0]
        nstack = 1
        for tidx in range(1000):
            # if tidx % nstack == 0:
            for j in range(numAgents):
                if tidx > 0:
                    input_imgs = np.expand_dims(
                        np.squeeze(np.stack(img_hist, -1)), 0)
                    # print(np.shape(input_imgs))
                    # plt.imshow(input_imgs[0, :, :, 0])
                    # plt.imshow(input_imgs[0, :, :, 1])
                    # plt.draw()
                    # plt.pause(0.000001)
                    if input_imgs.shape == (1, 84, 84, 3):
                        actions, values, states = model[j].step_model.step(
                            input_imgs)
                    else:
                        actions, values, states = model[j].step_model.step(
                            input_imgs[:, :, :, :, 0])
                    # actions, values, states = model.step_model.step(input_imgs)
                    action = actions[0]
                    value = values[0]
                    # print('Value: ', value, '   Action: ', action)

                img, reward, done, _ = env.step(action, j)
                total_rewards[j] += reward
                # img = get_img(env)
                img_hist[j].append(img[j])
                imsave(
                    os.path.join(frames_dir[j],
                                 'frame_{:04d}.png'.format(tidx)),
                    resize(img[j], (img_shape[0], img_shape[1], 3)))
            # print(tidx, '\tAction: ', action, '\tValues: ', value, '\tRewards: ', reward, '\tTotal rewards: ', total_rewards)#, flush=True)
            if done:
                # print('Faultered at tidx: ', tidx)
                for j in range(numAgents):
                    rwd[j].append(total_rewards[j])
                    percent_exp[j].append(env.env.percent_explored[j])
                # env.env, img = env.reset()
                break
    for i in range(numAgents):
        print('-----------------------------------')
        print('Agent ' + str(i))
        print('Iteration: ', iters)
        avg_rwd = sum(rwd[i]) / iters
        avg_pct_exp = sum(percent_exp[i]) / iters
        med_pct_exp = statistics.median(percent_exp[i])
        print('Average Reward: ', avg_rwd)
        print('Average Percent Explored: ', avg_pct_exp, '%')
        print('Median Percent Explored: ', med_pct_exp)
        print('-----------------------------------')
Example #8
def test(env_id, policy_name, seed, nstack=1, numAgents=2, benchmark=False):
    iters = 100
    rwd = []
    percent_exp = []
    env = EnvVec([
        make_env(env_id, benchmark=benchmark, rank=idx, seed=seed)
        for idx in range(1)
    ],
                 particleEnv=True,
                 test=True)
    # print(env_id)
    # print("logger dir: ", logger.get_dir())
    # env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir()))
    continuous_actions = False  # defined here so the branch below does not read it before assignment
    if env_id == 'Pendulum-v0':
        if continuous_actions:
            env.action_space.n = env.action_space.shape[0]
        else:
            env.action_space.n = 10
    gym.logger.setLevel(logging.WARN)
    ob_space = env.observation_space
    ac_space = env.action_space

    # def get_img(env):
    #     ax, img = env.get_img()
    #    return ax, img

    # def process_img(img):
    #     img = rgb2grey(copy.deepcopy(img))
    #    img = resize(img, img_shape)
    #    return img

    policy_fn = policy_fn_name(policy_name)

    nsteps = 5
    total_timesteps = int(80e6)
    vf_coef = 0.9
    ent_coef = 0.01
    max_grad_norm = 0.5
    lr = 7e-4
    lrschedule = 'linear'
    epsilon = 1e-5
    alpha = 0.99
    debug = False
    if numAgents == 1:
        model = Model(policy=policy_fn,
                      ob_space=ob_space,
                      ac_space=ac_space,
                      nenvs=1,
                      nsteps=nsteps,
                      nstack=nstack,
                      num_procs=1,
                      ent_coef=ent_coef,
                      vf_coef=vf_coef,
                      max_grad_norm=max_grad_norm,
                      lr=lr,
                      alpha=alpha,
                      epsilon=epsilon,
                      total_timesteps=total_timesteps,
                      lrschedule=lrschedule,
                      continuous_actions=continuous_actions,
                      debug=debug)
        m_name = 'test_model_Mar7_1mil.pkl'
        model.load(m_name)
    else:
        model = []
        for i in range(numAgents):
            model.append(
                Model(policy=policy_fn,
                      ob_space=ob_space,
                      ac_space=ac_space,
                      nenvs=1,
                      nsteps=nsteps,
                      nstack=nstack,
                      num_procs=1,
                      ent_coef=ent_coef,
                      vf_coef=vf_coef,
                      max_grad_norm=max_grad_norm,
                      lr=lr,
                      alpha=alpha,
                      epsilon=epsilon,
                      total_timesteps=total_timesteps,
                      lrschedule=lrschedule,
                      continuous_actions=continuous_actions,
                      debug=debug,
                      itr=i,
                      particleEnv=True))
        for i in range(numAgents):
            m_name = 'partEnv_model_' + str(i) + '.pkl'
            model[i].load(m_name)
            print('---------------------------------------------')
            print("Successfully Loaded: ", m_name)
            print('---------------------------------------------')

    obs = env.reset()
    states = [[], []]
    dones = [False, False]
    rwd = [[], []]
    percent_exp = [[], []]
    for i in range(1, iters + 1):
        if i % 1 == 0:
            for j in range(numAgents):
                print('-----------------------------------')
                print('Agent ' + str(j))
                print('Iteration: ', i)
                avg_rwd = sum(rwd[j]) / i
                # avg_pct_exp = sum(percent_exp[j])/i
                # med_pct_exp = statistics.median(percent_exp[j])
                print('Average Reward: ', avg_rwd)
                # print('Average Percent Explored: ', avg_pct_exp, '%')
                # print('Median Percent Explored: ', med_pct_exp)
                print('-----------------------------------')
        actions = [[], []]
        values = [[], []]
        total_rewards = [[0], [0]]
        nstack = 1
        for tidx in range(1000):
            # if tidx % nstack == 0:
            for j in range(numAgents):
                # if tidx > 0:
                # input_imgs = np.expand_dims(np.squeeze(np.stack(img_hist, -1)), 0)
                # print(np.shape(input_imgs))
                # plt.imshow(input_imgs[0, :, :, 0])
                # plt.imshow(input_imgs[0, :, :, 1])
                # plt.draw()
                # plt.pause(0.000001)
                # print(obs[:, j])
                # print(states[j])
                # print(dones)
                # actions[j], values[j], states[j] = model[j].step(obs[:, j].reshape(1, 21), states[j], dones[j])
                ob_shape = np.asarray([
                    env.observation_space[i].shape for i in range(env.n)
                ]).flatten()
                print(ob_shape)
                actions[j], values[j], states[j] = model[j].step(
                    obs[:, j].reshape(1, ob_shape[1]), states[j], dones[j])
                # action = actions[0]
                # value = values[0]

            obs, rewards, dones, _ = env.step(actions)
            dones = dones.flatten()
            total_rewards += rewards  # wrong?
            print(total_rewards)
            # print(dones)
            # img = get_img(env)
            # obs_hist[j].append(img[j])
            # imsave(os.path.join(frames_dir[j], 'frame_{:04d}.png'.format(tidx)), resize(img[j], (img_shape[0], img_shape[1], 3)))
            # print(tidx, '\tAction: ', action, '\tValues: ', value, '\tRewards: ', reward, '\tTotal rewards: ', total_rewards)#, flush=True)
            if True in dones:
                # print('Faultered at tidx: ', tidx)
                for j in range(numAgents):
                    rwd[j].append(total_rewards[j])
                    # percent_exp[j].append(env.env.percent_explored[j])
                obs = env.reset()
                break
    for i in range(numAgents):
        print('-----------------------------------')
        print('Agent ' + str(i))
        print('Iteration: ', iters)
        avg_rwd = sum(rwd[i]) / iters
        # avg_pct_exp = sum(percent_exp[i])/iters
        # med_pct_exp = statistics.median(percent_exp[i])
        print('Average Reward: ', avg_rwd)
        # print('Average Percent Explored: ', avg_pct_exp, '%')
        # print('Median Percent Explored: ', med_pct_exp)
        print('-----------------------------------')
Example #9
def main(visualize=False):
    session = tf_util.make_session()
    env_model = EnvNetwork(action_space_size=6,
                           nbatch=num_env * singlestep,
                           K=K,
                           nsteps=singlestep,
                           reuse=False,
                           session=session)
    session.run(tf.global_variables_initializer())
    env_model.restore()

    env = VecFrameStack(make_doom_env(num_env, seed, 'mixed'), 4)
    navi_model = Model(policy=CnnPolicy,
                       ob_space=env.observation_space,
                       ac_space=Discrete(3),
                       nenvs=num_env,
                       nsteps=nsteps,
                       ent_coef=0.01,
                       vf_coef=0.5,
                       max_grad_norm=0.5,
                       lr=7e-4,
                       alpha=0.99,
                       epsilon=1e-5,
                       total_timesteps=total_timesteps,
                       lrschedule='linear',
                       model_name='navi')
    navi_model.load("O:\\Doom\\baselinemodel\\navigate_flat2.dat")

    fire_model = Model(policy=CnnPolicy,
                       ob_space=env.observation_space,
                       ac_space=Discrete(3),
                       nenvs=num_env,
                       nsteps=nsteps,
                       ent_coef=0.01,
                       vf_coef=0.5,
                       max_grad_norm=0.5,
                       lr=7e-4,
                       alpha=0.99,
                       epsilon=1e-5,
                       total_timesteps=total_timesteps,
                       lrschedule='linear',
                       model_name='fire')

    fire_model.load("O:\\Doom\\baselinemodel\\fire_flat2.dat")
    policy_model = MixedModel(navi_model, fire_model, check_enemy_leave,
                              check_enemy_enter, [0, 1, 4], [0, 1, 5])
    runner = Runner(env, policy_model, nsteps=nsteps, gamma=0.99)

    nh, nw, nc = env.observation_space.shape

    while True:
        total_loss = 0
        for _ in tqdm(range(save_freq)):
            obs1, _, _, mask1, actions1, _ = runner.run()

            obs1 = np.reshape(obs1, [num_env, nsteps, nh, nw, nc])
            obs1 = obs1[:, :, :, :, -1:]

            actions1 = np.reshape(actions1, [num_env, nsteps])
            mask1 = np.reshape(mask1, [num_env, nsteps])

            hidden_states = env_model.initial_state
            for s in range(0, nsteps - K - singlestep, singlestep):
                input_frames = obs1[:,
                                    s:s + singlestep, :, :, :] // norm_factor
                input_frames = np.reshape(input_frames,
                                          [num_env * singlestep, nh, nw])
                input_frames = np.eye(9)[input_frames]
                actions, masks, expected_observations = [], [], []
                for t in range(K):
                    expected_observation = obs1[:, s + t + 1:s + singlestep +
                                                t + 1, :, :, :]
                    expected_observation = np.reshape(
                        expected_observation,
                        [num_env * singlestep, nh, nw, 1])
                    expected_observations.append(expected_observation)

                    action = actions1[:, s + t:s + singlestep + t]
                    action = np.reshape(action, [num_env * singlestep])
                    actions.append(action)

                    mask = mask1[:, s + t:s + singlestep + t]
                    mask = np.reshape(mask, [num_env * singlestep])
                    masks.append(mask)

                if s > 0:
                    loss, prediction, hidden_states = env_model.train_and_predict(
                        input_frames, actions, masks, expected_observations,
                        hidden_states)
                    total_loss += loss
                else:
                    # warm up
                    prediction, hidden_states = env_model.predict(
                        input_frames, actions, masks, hidden_states)

                if visualize and s == 3 * singlestep:
                    for batch_idx in range(num_env * singlestep):
                        expected_t = expected_observations[0]
                        if np.sum(expected_t[batch_idx, :, :, :] > 0.0):
                            input_frame = input_frames[batch_idx, :, :, :]
                            cv2.imshow('input', input_frame)
                            for i in range(K):
                                time_t_expectation = expected_observations[i]
                                exp_obs = time_t_expectation[
                                    batch_idx, :, :, :]
                                cv2.imshow('expected for t+{}'.format(i + 1),
                                           exp_obs)
                            for i in range(K):
                                time_t_prediction = prediction[i]
                                cv2.imshow(
                                    'prediction for t+{}'.format(i + 1),
                                    time_t_prediction[batch_idx, :, :, 7])
                            cv2.waitKey(0)

        print("avg_loss = {}".format(total_loss / K / save_freq /
                                     valid_batch_size))
        env_model.save()