Example #1
def train(env_id, num_timesteps, seed, policy, lrschedule, num_env):
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = LstmPolicy
    elif policy == 'lnlstm':
        policy_fn = LnLstmPolicy
    else:
        raise ValueError("Unknown policy: {}".format(policy))
    env = VecFrameStack(make_atari_env(env_id, num_env, seed), NUM_CPU)
    learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), lrschedule=lrschedule)
    env.close()
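A minimal invocation sketch for the train function above, assuming the baselines-style A2C setup these snippets appear to come from (CnnPolicy/LstmPolicy/LnLstmPolicy and learn imported from the A2C module); the argument values below are illustrative only:

train(env_id='BreakoutNoFrameskip-v4', num_timesteps=int(10e6), seed=0,
      policy='cnn', lrschedule='constant', num_env=16)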
Example #2
def train(env_id, num_timesteps, seed, num_cpu):
    env = VecFrameStack(make_atari_env(env_id, NUM_CPU, seed), NUM_ENV)
    policy_fn = partial(CnnPolicy)  #, one_dim_bias=True)
    learn(policy_fn,
          env,
          seed,
          total_timesteps=int(num_timesteps * 1.1),
          nprocs=num_cpu,
          log_interval=20,
          save_interval=1000)
    env.close()
Example #3
def train(*, env_id, num_env, hps, num_timesteps, seed):
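    # Hyperparameters are consumed from hps via pop() as they are used, so the
    # assert near the end can flag any hyperparameter that was passed but never read.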
    venv = VecFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True
    # venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy,
              'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")
        ),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
    )
    agent.start_interaction([venv])
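    # Optionally warm up the observation normalization statistics by running a
    # uniformly random policy before training starts.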
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128*50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats['tcount'] > num_timesteps:
            break

    agent.stop_interaction()

    return agent
Example #4
def train(*, env_id, num_env, hps, num_timesteps, seed):
    venv = VecFrameStack(
        make_atari_env(env_id,
                       num_env,
                       seed,
                       wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    venv.score_multiple = 1
    venv.record_obs = False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]

    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            exploration_type=hps.pop("exploration_type"),
            beta=hps.pop("beta"),
        ),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        noise_type=hps.pop('noise_type'),
        noise_p=hps.pop('noise_p'),
        use_sched=hps.pop('use_sched'),
        num_env=num_env,
        exp_name=hps.pop('exp_name'),
    )
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    counter = 0
    while True:
        info = agent.step()
        n_updates = 0

        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()

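            # When running under NSML, also forward the numeric log values to the NSML dashboard.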
            if NSML:
                n_updates = int(info['update']['n_updates'])
                nsml_dict = {
                    k: np.float64(v)
                    for k, v in info['update'].items()
                    if isinstance(v, Number)
                }
                nsml.report(step=n_updates, **nsml_dict)

            counter += 1
        #if n_updates >= 40*1000: # 40K updates
        #    break
        if agent.I.stats['tcount'] > num_timesteps:
            break

    agent.stop_interaction()
Example #5
def train(*, env_id, num_env, hps, num_timesteps, seed):

    venv = VecFrameStack(
        make_atari_env(env_id,
                       num_env,
                       seed,
                       wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))

    # Size of states when stored in the memory.
    only_train_r = hps.pop('only_train_r')

    online_r_training = hps.pop('online_train_r') or only_train_r

    r_network_trainer = None
    save_path = hps.pop('save_path')
    r_network_weights_path = hps.pop('r_path')
    '''
    ec_type = 'none' # hps.pop('ec_type')

    venv = CuriosityEnvWrapperFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        vec_episodic_memory = None,
        observation_embedding_fn = None,
        exploration_reward = ec_type,
        exploration_reward_min_step = 0,
        nstack = hps.pop('frame_stack'),
        only_train_r = only_train_r
        )
    '''

    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')

    log_interval = hps.pop('log_interval')

    nminibatches = hps.pop('nminibatches')

    play = hps.pop('play')

    if play:
        nsteps = 1

    rnd_type = hps.pop('rnd_type')
    div_type = hps.pop('div_type')

    num_agents = hps.pop('num_agents')

    load_ram = hps.pop('load_ram')

    debug = hps.pop('debug')

    rnd_mask_prob = hps.pop('rnd_mask_prob')

    rnd_mask_type = hps.pop('rnd_mask_type')

    indep_rnd = hps.pop('indep_rnd')
    logger.info("indep_rnd:", indep_rnd)
    indep_policy = hps.pop('indep_policy')

    sd_type = hps.pop('sd_type')

    from_scratch = hps.pop('from_scratch')

    use_kl = hps.pop('use_kl')

    save_interval = 100

    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus"),
            num_agents=num_agents,
            rnd_type=rnd_type,
            div_type=div_type,
            indep_rnd=indep_rnd,
            indep_policy=indep_policy,
            sd_type=sd_type,
            rnd_mask_prob=rnd_mask_prob),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        gamma_div=hps.pop('gamma_div'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=nminibatches,
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=5 if debug else 128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        log_interval=log_interval,
        only_train_r=only_train_r,
        rnd_type=rnd_type,
        reset=hps.pop('reset'),
        dynamics_sample=hps.pop('dynamics_sample'),
        save_path=save_path,
        num_agents=num_agents,
        div_type=div_type,
        load_ram=load_ram,
        debug=debug,
        rnd_mask_prob=rnd_mask_prob,
        rnd_mask_type=rnd_mask_type,
        sd_type=sd_type,
        from_scratch=from_scratch,
        use_kl=use_kl,
        indep_rnd=indep_rnd)

    load_path = hps.pop('load_path')
    base_load_path = hps.pop('base_load_path')

    agent.start_interaction([venv])
    if load_path is not None:

        if play:
            agent.load(load_path)
        else:
            #agent.load(load_path)
            #agent.load_help_info(0, load_path)
            #agent.load_help_info(1, load_path)

            #load diversity agent
            #base_agent_idx = 1
            #logger.info("load base  agents weights from {}  agent {}".format(base_load_path, str(base_agent_idx)))
            #agent.load_agent(base_agent_idx, base_load_path)
            #agent.clone_baseline_agent(base_agent_idx)
            #agent.load_help_info(0, dagent_load_path)
            #agent.clone_agent(0)

            # load main agent 1
            src_agent_idx = 1

            logger.info("load main agent weights from {} agent {}".format(
                load_path, str(src_agent_idx)))
            agent.load_agent(src_agent_idx, load_path)

            if not indep_rnd:
                rnd_agent_idx = 1
            else:
                rnd_agent_idx = src_agent_idx
            #rnd_agent_idx = 0
            logger.info("load rnd weights from {} agent {}".format(
                load_path, str(rnd_agent_idx)))
            agent.load_rnd(rnd_agent_idx, load_path)
            agent.clone_agent(rnd_agent_idx,
                              rnd=True,
                              policy=False,
                              help_info=False)

            logger.info("load help info from {} agent {}".format(
                load_path, str(src_agent_idx)))
            agent.load_help_info(src_agent_idx, load_path)

            agent.clone_agent(src_agent_idx,
                              rnd=False,
                              policy=True,
                              help_info=True)

            #logger.info("load main agent weights from {} agent {}".format(load_path, str(2)))

            #load_path = '/data/xupeifrom7700_1000/seed1_log0.5_clip-0.5~0.5_3agent_hasint4_2divrew_-1~1/models'
            #agent.load_agent(1, load_path)

            #agent.clone_baseline_agent()
            #if sd_type =='sd':
            #    agent.load_sd("save_dir/models_sd_trained")

        #agent.initialize_discriminator()

    update_ob_stats_from_random_agent = hps.pop(
        'update_ob_stats_from_random_agent')
    if not play:

        if load_path is not None:
            pass  #agent.collect_statistics_from_model()
        else:
            if update_ob_stats_from_random_agent and rnd_type == 'rnd':
                agent.collect_random_statistics(num_timesteps=128 *
                                                5 if debug else 128 * 50)
        assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

        #agent.collect_rnd_info(128*50)
        '''
        if sd_type=='sd':
            agent.train_sd(max_nepoch=300, max_neps=5)
            path = '{}_sd_trained'.format(save_path)
            logger.log("save model:",path)
            agent.save(path)

            return
            #agent.update_diverse_agent(max_nepoch=1000)
            #path = '{}_divupdated'.format(save_path)
            #logger.log("save model:",path)
            #agent.save(path)
        '''
        counter = 0
        while True:
            info = agent.step()

            n_updates = agent.I.stats["n_updates"]
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
                counter += 1

            if info['update'] and save_path is not None and (
                    n_updates % save_interval == 0 or n_updates == 1):
                path = '{}_{}'.format(save_path, str(n_updates))
                logger.log("save model:", path)
                agent.save(path)
                agent.save_help_info(save_path, n_updates)

            if agent.I.stats['tcount'] > num_timesteps:
                path = '{}_{}'.format(save_path, str(n_updates))
                logger.log("save model:", path)
                agent.save(path)
                agent.save_help_info(save_path, n_updates)
                break
        agent.stop_interaction()
    else:
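        # Interactive 'play' mode: step a single environment, optionally overriding the
        # policy's action from keyboard input, and print value/diversity diagnostics each step.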
        '''
        check_point_rews_list_path ='{}_rewslist'.format(load_path)
        check_point_rnd_path ='{}_rnd'.format(load_path)
        oracle_rnd = oracle.OracleExplorationRewardForAllEpisodes()
        oracle_rnd.load(check_point_rnd_path)
        #print(oracle_rnd._collected_positions_writer)
        #print(oracle_rnd._collected_positions_reader)

        rews_list = load_rews_list(check_point_rews_list_path)
        print(rews_list)
        '''

        istate = agent.stochpol.initial_state(1)
        #ph_mean, ph_std = agent.stochpol.get_ph_mean_std()

        last_obs, prevrews, ec_rews, news, infos, ram_states, _ = agent.env_get(
            0)
        agent.I.step_count += 1

        flag = False
        show_cam = True

        last_xr = 0

        restore = None
        '''
        #path = 'ram_state_500_7room'
        #path='ram_state_400_6room'
        #path='ram_state_6700' 
        path='ram_state_7700_10room'
        f = open(path,'rb')
        restore = pickle.load(f)
        f.close()
        last_obs[0] = agent.I.venvs[0].restore_full_state_by_idx(restore,0)
        print(last_obs.shape)

        #path = 'ram_state_400_monitor_rews_6room'
        #path = 'ram_state_500_monitor_rews_7room'
        #path='ram_state_6700_monitor_rews'
        path='ram_state_7700_monitor_rews_10room'
        f = open(path,'rb')
        monitor_rews = pickle.load(f)
        f.close()
        
        agent.I.venvs[0].set_cur_monitor_rewards_by_idx(monitor_rews,0)
        '''

        agent_idx = np.asarray([0])
        sample_agent_prob = np.asarray([0.5])

        ph_mean = agent.stochpol.ob_rms_list[0].mean
        ph_std = agent.stochpol.ob_rms_list[0].var**0.5

        buf_ph_mean = np.zeros(
            ([1, 1] + list(agent.stochpol.ob_space.shape[:2]) + [1]),
            np.float32)
        buf_ph_std = np.zeros(
            ([1, 1] + list(agent.stochpol.ob_space.shape[:2]) + [1]),
            np.float32)

        buf_ph_mean[0, 0] = ph_mean
        buf_ph_std[0, 0] = ph_std

        vpreds_ext_list = []

        ep_rews = np.zeros((1))
        divexp_flag = False
        step_count = 0
        stage_prob = True

        last_rew_ob = np.full_like(last_obs, 128)

        clusters = Clusters(1.0)

        #path = '{}_sd_rms'.format(load_path)
        #agent.I.sd_rms.load(path)

        while True:

            dict_obs = agent.stochpol.ensure_observation_is_dict(last_obs)

            #acs= np.random.randint(low=0, high=15, size=(1))
            acs, vpreds_int, vpreds_ext, nlps, istate, ent = agent.stochpol.call(
                dict_obs, news, istate, agent_idx[:, None])

            step_acs = acs
            t = ''
            #if show_cam==True:
            t = input("input:")
            if t != '':
                t = int(t)
                if t <= 17:
                    step_acs = [t]

            agent.env_step(0, step_acs)

            obs, prevrews, ec_rews, news, infos, ram_states, monitor_rews = agent.env_get(
                0)

            if news[0] and restore is not None:
                obs[0] = agent.I.venvs[0].restore_full_state_by_idx(restore, 0)
                agent.I.venvs[0].set_cur_monitor_rewards_by_idx(
                    monitor_rews, 0)

            ep_rews = ep_rews + prevrews

            print(ep_rews)

            last_rew_ob[prevrews > 0] = obs[prevrews > 0]

            room = infos[0]['position'][2]
            vpreds_ext_list.append([vpreds_ext, room])
            #print(monitor_rews[0])
            #print(len(monitor_rews[0]))
            #print(infos[0]['open_door_type'])

            stack_obs = np.concatenate([last_obs[:, None], obs[:, None]], 1)

            fd = {}

            fd[agent.stochpol.ph_ob[None]] = stack_obs

            fd.update({
                agent.stochpol.sep_ph_mean: buf_ph_mean,
                agent.stochpol.sep_ph_std: buf_ph_std
            })
            fd[agent.stochpol.ph_agent_idx] = agent_idx[:, None]
            fd[agent.stochpol.sample_agent_prob] = sample_agent_prob[:, None]

            fd[agent.stochpol.last_rew_ob] = last_rew_ob[:, None]
            fd[agent.stochpol.game_score] = ep_rews[:, None]

            fd[agent.stochpol.sd_ph_mean] = agent.I.sd_rms.mean
            fd[agent.stochpol.sd_ph_std] = agent.I.sd_rms.var**0.5

            div_prob = 0

            all_div_prob = tf_util.get_session().run(
                [agent.stochpol.all_div_prob], fd)
            '''
            if prevrews[0] > 0:
                clusters.update(rnd_em,room)
    
                num_clusters = len(clusters._cluster_list)
                for i in range(num_clusters):
                    print("{} {}".format(str(i),list(clusters._room_set[i])))
            '''
            print("vpreds_int: ", vpreds_int, "vpreds_ext:", vpreds_ext,
                  "ent:", ent, "all_div_prob:", all_div_prob, "room:", room,
                  "step_count:", step_count)

            #aaaa = np.asarray(vpreds_ext_list)
            #print(aaaa[-100:])
            '''
            if step_acs[0]==0:
                ram_state = ram_states[0]
                path='ram_state_7700_10room'
                f = open(path,'wb')
                pickle.dump(ram_state,f)
                f.close()

                path='ram_state_7700_monitor_rews_10room'
                f = open(path,'wb')
                pickle.dump(monitor_rews[0],f)
                f.close()
            '''
            '''
            if  restore is None:
                restore = ram_states[0]

            
            if np.random.rand() < 0.1:
                print("restore")
                obs = agent.I.venvs[0].restore_full_state_by_idx(restore,0)
                prevrews = None
                ec_rews = None
                news= True
                infos = {}
                ram_states = ram_states[0]

                #restore = ram_states[0]
            '''

            img = agent.I.venvs[0].render()

            last_obs = obs

            step_count = step_count + 1

            time.sleep(0.04)
Example #6
def load_test(*, env_id, num_env, hps, num_timesteps, seed, fname):
    venv = VecFrameStack(
        make_atari_env(env_id,
                       num_env,
                       seed,
                       wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        obs_save_flag=True)

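    # Restore previously saved TensorFlow variables, roll out until one full episode
    # has been recorded, then dump the recorded observations/actions to disk.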
    tf_util.load_state("saved_states/save1")
    agent.start_interaction([venv])
    counter = 0
    while True:
        info = agent.step()
        if agent.I.stats['epcount'] > 1:
            with open("obs_acs.pickle", 'wb') as f1:
                pickle.dump(agent.obs_rec, f1)
            break
Example #7
def record_():
    model_path = args.load_model
    assert os.path.isfile(model_path), "model file not found: {}".format(model_path)

    # search skills

    m = re.search(r"\[[0-9, \[\]]*\]", model_path)
    if m is None:
        raise ValueError(
            "load_model: {} does not contain skills".format(model_path))
    skills = str_to_skills(m.group(0))

    # search env-id
    env_id_list = ENV_LIST
    env_id = None
    searched = False
    m = re.search("[A-Z][a-z]*NoFrameskip-v4", model_path)
    if m is not None:
        searched = True
        env_id = m.group(0)

    if searched is not True:
        for id_ in env_id_list:
            if id_.lower() in model_path.lower():
                searched = True
                env_id = id_ + "NoFrameskip-v4"

    if searched is not True:
        raise ValueError(
            "load_model: {} does not contain env id".format(model_path))

    save_path = args.logdir
    if save_path is None:
        save_path = os.path.dirname(model_path)

    print("ENV:{} \nskills:{} \nmodel_path:{} \nsave_path:{}\n".format(
        env_id, skills, model_path, save_path))
    time.sleep(3)

    env_creator_ = lambda env: ActionRemapWrapper(env)
    env_creator = lambda env: SkillWrapper(env_creator_(env), skills=skills)
    env = VecFrameStack(
        make_atari_env(env_id,
                       1,
                       args.seed,
                       extra_wrapper_func=env_creator,
                       logdir=save_path,
                       wrapper_kwargs={
                           "episode_life": False,
                           "clip_rewards": False
                       }), 4)

    if args.load_model is None:
        raise NotImplementedError
    assert os.path.isfile(args.load_model)

    if args.rl_model == "ppo":
        model = PPO2.load(args.load_model)
    elif args.rl_model == "a2c":
        model = A2C.load(args.load_model)
    elif args.rl_model is None:
        if "ppo" in model_path:
            model = PPO2.load(model_path)
        elif "a2c" in model_path:
            model = A2C.load(model_path)
        else:
            raise ValueError("please specify rl_model")
    else:
        raise ValueError("{} rl_model not recognize".format(args.rl_model))

    # DEBUG
    set_global_seeds(args.seed)

    obs = env.reset()
    if args.record:
        env = VecVideoRecorder(env,
                               save_path,
                               record_video_trigger=lambda x: x == 0,
                               video_length=MAX_VIDEO_LENGTH)
        env.reset()
    total_rewards = 0

    action_save_path = os.path.join(save_path, "history_action.txt")
    if args.log_action:
        try:
            os.remove(action_save_path)
        except OSError as e:
            if e.errno != errno.ENOENT:  # errno.ENOENT = no such file or directory
                raise  # re-raise exception if a different error occurred
    log_picture = None
    if args.log_picture:
        log_picture = os.path.join(save_path, "history_action_pic")
        log_picture = mkdirs(log_picture, mode="keep")
        action_save_path = os.path.join(log_picture,
                                        os.path.basename(action_save_path))
        # try:
        #     # shutil.rmtree()
        # except:

    print("start evaluate")
    with open(action_save_path, 'a') as f:
        for steps in range(args.eval_max_steps):
            action, _states = model.predict(obs)
            if args.log_action:
                # print("{}".format(action[0]), sep=" ", file=f)
                f.write("{} ".format(action[0]))
            if args.log_picture:
                assert log_picture is not None
                pict = env.render(mode='rgb_array')

                im = Image.fromarray(pict)
                _path = os.path.join(log_picture,
                                     "{}_{}.jpg".format(steps, action[0]))
                im.save(_path)
            obs, rewards, dones, info = env.step(action)
            total_rewards += rewards
            if dones[0]:
                break
    print("steps: {}/{}".format(steps + 1, args.eval_max_steps))
    print("total_rewards: {}".format(total_rewards))
    env.close()
Example #8
def train(*,
          env_id,
          num_env,
          hps,
          num_timesteps,
          seed,
          test=False,
          e_greedy=0,
          **kwargs):
    # import pdb; pdb.set_trace()
    print('_________________________________________________________')
    pprint.pprint(f'hyperparams: {hps}', width=1)
    pprint.pprint(f'additional hyperparams: {kwargs}', width=1)
    print('_________________________________________________________')
    # TODO: this is just for debugging
    # tmp_env = make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(clip_rewards=kwargs['clip_rewards']))
    venv = VecFrameStack(
        make_atari_env(
            env_id,
            num_env,
            seed,
            wrapper_kwargs=dict(clip_rewards=kwargs['clip_rewards']),
            start_index=num_env * MPI.COMM_WORLD.Get_rank(),
            max_episode_steps=hps.pop('max_episode_steps'),
            action_space=kwargs['action_space']), hps.pop('frame_stack'))
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1  # TODO: understand what score_multiple does
    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
    )
    saver = tf.train.Saver()
    if test:
        agent.restore_model(saver,
                            kwargs['load_dir'],
                            kwargs['exp_name'],
                            mtype=kwargs['load_mtype'])
    agent.start_interaction([venv])

    import time
    st = time.time()
    if hps.pop('update_ob_stats_from_random_agent'):
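        # Cache the random-agent observation statistics on disk so later runs can
        # reload them instead of re-collecting.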
        if os.path.exists('./data/ob_rms.pkl'):
            with open('./data/ob_rms.pkl', 'rb') as handle:
                agent.stochpol.ob_rms.mean, agent.stochpol.ob_rms.var, agent.stochpol.ob_rms.count = pickle.load(
                    handle)

        else:
            ob_rms = agent.collect_random_statistics(
                num_timesteps=128 * 50)  # original number128*50
            with open('./data/ob_rms.pkl', 'wb') as handle:
                pickle.dump([ob_rms.mean, ob_rms.var, ob_rms.count],
                            handle,
                            protocol=2)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())
    print(f'Time duration {time.time() - st}')

    if not test:
        counter = 1
        while True:
            info = agent.step()
            if info['update']:
                # import pdb; pdb.set_trace()
                logger.logkvs(info['update'])
                logger.dumpkvs()
            if agent.I.stats['tcount'] // 1e7 == counter:
                agent.save_model(saver,
                                 kwargs['save_dir'],
                                 kwargs['exp_name'],
                                 mtype=str(counter * 1e7))
                counter += 1
            if agent.I.stats['tcount'] > num_timesteps:
                break
    if test:
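        # Evaluation mode: roll out the restored policy for 1000 episodes per file and
        # pickle each batch of trajectories.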
        for pkls in range(0, 10):
            print('collecting the', pkls, 'th pickle', ' ' * 20)
            all_rollout = agent.evaluate(env_id=env_id,
                                         seed=seed,
                                         episodes=1000,
                                         save_image=kwargs['save_image'],
                                         e_greedy=e_greedy)
            with open('./data/newworld21_' + str(pkls).zfill(1) + '.pkl',
                      'wb') as handle:
                pickle.dump(all_rollout, handle)
    agent.stop_interaction()
    if not test:
        agent.save_model(saver, kwargs['save_dir'], kwargs['exp_name'])
Example #9
sys.path.append(os.path.abspath("./stable-baselines/stable_baselines"))
sys.path.append(os.path.abspath("../../lib"))

from env_wrapper import SkillWrapper, ActionRemapWrapper
from stable_baselines.common.vec_env import VecFrameStack
from cmd_util import make_atari_env


ENV_ID = "SeaquestNoFrameskip-v4"
SEED = 1000
TEMP_LOGDIR = "./"
env_id = ENV_ID
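# Each skill is a fixed sequence of primitive actions made available to the agent via SkillWrapper.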
skills = [[2,2,2],[1,1,1]]
env_creator_ = lambda env: ActionRemapWrapper(env)
env_creator = lambda env: SkillWrapper(env_creator_(env), skills=skills)
env = VecFrameStack(
    make_atari_env(env_id, 1, SEED,
                   extra_wrapper_func=env_creator,
                   logdir=TEMP_LOGDIR,
                   wrapper_kwargs={"episode_life": False, "clip_rewards": False}), 4)


model = Macro(PPO2, CnnPolicy, env, verbose=1, macro_length=3, macro_num=None)

# model.act_model.set_clip_action_mask([1,1,1,1,1,0,0,0])
obs = env.reset()

model.learn(100000, eval_env=env.envs[0], eval_timesteps=2000, timesteps_per_epoch=1000)
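# Step the trained model for a few frames as a quick sanity check of the learned policy.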
for steps in range(5):
    action, _states = model.predict(obs)
    
    obs, rewards, dones, info = env.step(action)
    print(action)
    if dones[0]:
        break
Example #10
def train(*, env_id, num_env, hps, num_timesteps, seed):
    experiment = os.environ.get('EXPERIMENT_LVL')
    if experiment == 'ego':
        # for the ego experiment we needed a higher intrinsic coefficient
        hps['int_coeff'] = 3.0

    hyperparams = copy(hps)
    hyperparams.update({'seed': seed})
    logger.info("Hyperparameters:")
    logger.info(hyperparams)

    venv = VecFrameStack(
        make_atari_env(env_id,
                       num_env,
                       seed,
                       wrapper_kwargs={},
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    venv.score_multiple = {
        'Mario': 500,
        'MontezumaRevengeNoFrameskip-v4': 1,
        'GravitarNoFrameskip-v4': 250,
        'PrivateEyeNoFrameskip-v4': 500,
        'SolarisNoFrameskip-v4': None,
        'VentureNoFrameskip-v4': 200,
        'PitfallNoFrameskip-v4': 100,
    }[env_id]

    venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space

    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]

    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        restore_model_path=hps.pop('restore_model_path'))
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)

    save_model = hps.pop('save_model')
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    #profiler = cProfile.Profile()
    #profiler.enable()
    #tracemalloc.start()
    #prev_snap = tracemalloc.take_snapshot()
    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1

            # if (counter % 10) == 0:
            #     snapshot = tracemalloc.take_snapshot()
            #     top_stats = snapshot.compare_to(prev_snap, 'lineno')
            #     for stat in top_stats[:10]:
            #         print(stat)
            #     prev_snap = snapshot
            #     profiler.dump_stats("profile_rnd")

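            # Checkpoint the model every 100 logged updates when saving is enabled.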
            if (counter % 100) == 0 and save_model:
                agent.save_model(agent.I.step_count)

        if agent.I.stats['tcount'] > num_timesteps:
            break

    agent.stop_interaction()
Example #11
    def get_rewards(self,
                    skills=[],
                    train_total_timesteps=5000000,
                    eval_times=100,
                    eval_max_steps=int(1e6),
                    model_save_name=None,
                    add_info={}):
        """
        :param skills: (list) the available action sequences (macros) for the agent,
            e.g. [[0,2,2],[0,1,1]]
        :param train_total_timesteps: (int) total timesteps to train for
        :param eval_times: (int) number of evaluation episodes;
            e.g. eval_times=100 evaluates the policy by averaging the reward over 100 episodes
        :param eval_max_steps: (int) maximum timesteps per episode during evaluation
        :param model_save_name: (str, deprecated) name under which to save the model (should not repeat)
        :param add_info: (dict) other information to log in log.txt
        """

        if self.save_tensorboard and self.save_path is not None:
            tensorboard_log = os.path.join(self.save_path,
                                           "model_" + str(self._serial_num))
        else:
            tensorboard_log = None

        env_creator = lambda env: SkillWrapper(
            self.env_creator(env), skills=skills, gamma=self.gamma)

        if self.save_monitor is True:
            monitor_path = os.path.join(self.save_path, "monitor")
            try:
                os.makedirs(monitor_path)
            except OSError as ex:
                if ex.errno == errno.EEXIST and os.path.exists(monitor_path):
                    print("{} exists. ignore".format(monitor_path))
                    pass
                else:
                    raise
        else:
            monitor_path = None

        if "cfg" in self.env_id:

            env = make_doom_env(self.env_id,
                                self.num_cpu,
                                self.seed,
                                extra_wrapper_func=env_creator,
                                logdir=monitor_path)

        else:
            env = VecFrameStack(
                make_atari_env(self.env_id,
                               self.num_cpu,
                               self.seed,
                               extra_wrapper_func=env_creator,
                               logdir=monitor_path), 4)

        model = None
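        # When converge parameters are requested, build the model with an explicit set of
        # PPO2 hyperparameters (the values commonly used for Atari); otherwise rely on the
        # library defaults.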
        if self.use_converge_parameter is True:
            model = self.model(self.policy,
                               env,
                               verbose=self.verbose,
                               tensorboard_log=tensorboard_log,
                               n_steps=128,
                               nminibatches=4,
                               lam=0.95,
                               gamma=0.99,
                               noptepochs=4,
                               ent_coef=.01,
                               learning_rate=lambda f: f * 2.5e-4,
                               cliprange=lambda f: f * 0.1)
        else:
            model = self.model(self.policy,
                               env,
                               verbose=self.verbose,
                               tensorboard_log=tensorboard_log)

        self.strat_time = time.time()
        print("start to train agent...")

        callback = None
        if self.evaluate_freq is not None and self.evaluate_freq > 0:
            period_eval_path = os.path.join(self.save_path, "period_eval")
            mkdirs(period_eval_path)
            if "cfg" in self.env_id:

                eval_env = make_doom_env(self.env_id,
                                         self.num_cpu,
                                         self.seed,
                                         extra_wrapper_func=env_creator,
                                         logdir=monitor_path,
                                         wrapper_kwargs={
                                             "episode_life": False,
                                             "clip_rewards": False
                                         })
            else:
                eval_env = VecFrameStack(
                    make_atari_env(self.env_id,
                                   self.num_cpu,
                                   self.seed,
                                   extra_wrapper_func=env_creator,
                                   logdir=period_eval_path,
                                   wrapper_kwargs={
                                       "episode_life": False,
                                       "clip_rewards": False
                                   }), 4)
            callback = self.eval_callback(eval_env,
                                          freq=self.evaluate_freq,
                                          eval_times=eval_times,
                                          eval_max_steps=eval_max_steps,
                                          save_path=period_eval_path)

        model.learn(total_timesteps=train_total_timesteps,
                    reset_num_timesteps=self.reset_num_timesteps,
                    callback=callback)
        print("Finish train agent")

        # evaluate once more because total_timesteps is not always an exact multiple of the evaluation frequency
        if callback is not None:
            callback({"self": model, "eval_now": True}, None)

        if self.save_path is not None:
            if self.preserve_model > 0:

                self.save_model(model, skills=skills)

        env.close()
        # evaluate
        env = VecFrameStack(
            make_atari_env(self.env_id,
                           self.num_cpu,
                           self.seed,
                           extra_wrapper_func=env_creator,
                           logdir=None), 4)
        info = self.evaluate(env, model, eval_times, eval_max_steps)
        try:
            env.close()
        except AttributeError as e:
            print("Ignore : {}".format(e))
        try:
            del model
        except AttributeError as e:
            print("Ignore del model : {}".format(e))

        #log result
        info.update(add_info)
        self.log(info)

        self._serial_num = self._serial_num + 1
        return info["ave_score"], info["ave_action_reward"]