Example #1
def train(*, env_id, num_env, hps, num_timesteps, seed, use_reward, ep_path,
          dmlab):
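    # Builds a single MineRL Treechop env (episodes recorded to ep_path via
    # CollectGymDataset) and trains a PPO agent with an intrinsic exploration
    # bonus on it until num_timesteps environment steps have been collected.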
    venv = VecFrameStack(
        DummyVecEnv([
            lambda: CollectGymDataset(
                Minecraft('MineRLTreechop-v0', 'treechop'),
                ep_path,
                atari=False)
        ]), hps.pop('frame_stack'))
    venv.score_multiple = 1
    venv.record_obs = (env_id == 'SolarisNoFrameskip-v4')
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    print('Running train ==========================================')
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'))
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats['tcount'] > num_timesteps:
            break

    agent.stop_interaction()
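
A minimal usage sketch for the train() above. The hyperparameter values are illustrative assumptions, not settings from the original repository; the dict needs exactly the keys the function pops (a missing key raises a KeyError, a leftover key trips the final assert).

# Hypothetical invocation; all values below are placeholders.
hps = dict(
    frame_stack=4,
    policy='rnn',  # 'rnn' -> CnnGruPolicy, 'cnn' -> CnnPolicy
    gamma=0.99,
    gamma_ext=0.999,
    lam=0.95,
    nepochs=4,
    nminibatches=4,
    lr=1e-4,
    max_grad_norm=0.0,
    use_news=False,
    update_ob_stats_every_step=False,
    update_ob_stats_independently_per_gpu=False,
    update_ob_stats_from_random_agent=True,
    proportion_of_exp_used_for_predictor_update=1.0,
    dynamics_bonus=False,
    int_coeff=1.0,
    ext_coeff=2.0,
)
train(env_id='MineRLTreechop-v0', num_env=1, hps=hps, num_timesteps=int(1e8),
      seed=0, use_reward=True, ep_path='./episodes', dmlab=False)
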
Example #2
def train(*, env_id, num_env, hps, num_timesteps, seed):
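    # Atari variant: builds num_env vectorized envs with make_atari_env and,
    # when running under NSML, also reports the numeric per-update metrics there.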
    venv = VecFrameStack(
        make_atari_env(env_id,
                       num_env,
                       seed,
                       wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    venv.score_multiple = 1
    venv.record_obs = False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]

    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            exploration_type=hps.pop("exploration_type"),
            beta=hps.pop("beta"),
        ),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        noise_type=hps.pop('noise_type'),
        noise_p=hps.pop('noise_p'),
        use_sched=hps.pop('use_sched'),
        num_env=num_env,
        exp_name=hps.pop('exp_name'),
    )
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    counter = 0
    while True:
        info = agent.step()
        n_updates = 0

        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()

            if NSML:
                n_updates = int(info['update']['n_updates'])
                nsml_dict = {
                    k: np.float64(v)
                    for k, v in info['update'].items()
                    if isinstance(v, Number)
                }
                nsml.report(step=n_updates, **nsml_dict)

            counter += 1
        #if n_updates >= 40*1000: # 40K updates
        #    break
        if agent.I.stats['tcount'] > num_timesteps:
            break

    agent.stop_interaction()
def train(*, env_id, num_env, hps, num_timesteps, seed):
    venv = VecFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True
    # venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy,
              'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
    )
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128*50)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats['tcount'] > num_timesteps:
            break

    agent.stop_interaction()

    return agent
Example #4
def train(*, env_id, num_env, hps, num_timesteps, seed):
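    # Picks the Atari or non-Atari env factory based on the env id and threads a
    # meta_rl flag through to both the policy and the agent.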
    if "NoFrameskip" in env_id:
        env_factory = make_atari_env
    else:
        env_factory = make_non_atari_env

    venv = VecFrameStack(
        env_factory(
            env_id,
            num_env,
            seed,
            wrapper_kwargs=dict(),
            start_index=num_env * MPI.COMM_WORLD.Get_rank(),
            max_episode_steps=hps.pop("max_episode_steps"),
        ),
        hps.pop("frame_stack"),
    )
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = (env_id == "SolarisNoFrameskip-v4")
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop("gamma")
    policy = {
        "rnn": CnnGruPolicy,
        "cnn": CnnPolicy,
        "ffnn": GruPolicy,
    }[hps.pop("policy")]
    agent = PpoAgent(
        scope="ppo",
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope="pol",
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                "update_ob_stats_independently_per_gpu"),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                "proportion_of_exp_used_for_predictor_update"),
            dynamics_bonus=hps.pop("dynamics_bonus"),
            # meta_rl is only read here; it is popped from hps below when the agent itself is built.
            meta_rl=hps['meta_rl']),
        gamma=gamma,
        gamma_ext=hps.pop("gamma_ext"),
        lam=hps.pop("lam"),
        nepochs=hps.pop("nepochs"),
        nminibatches=hps.pop("nminibatches"),
        lr=hps.pop("lr"),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop("max_grad_norm"),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop("update_ob_stats_every_step"),
        int_coeff=hps.pop("int_coeff"),
        ext_coeff=hps.pop("ext_coeff"),
        meta_rl=hps.pop('meta_rl'))
    agent.start_interaction([venv])
    if hps.pop("update_ob_stats_from_random_agent"):
        agent.collect_random_statistics(num_timesteps=128 * 50)
    assert len(hps) == 0, f"Unused hyperparameters: {list(hps.keys())}"

    counter = 0
    while True:
        info = agent.step()
        if info["update"]:
            logger.logkvs(info["update"])
            logger.dumpkvs()
            counter += 1
        if agent.I.stats["tcount"] > num_timesteps:
            break

    agent.stop_interaction()
Example #5
def train(*, env_id, num_env, hps, num_timesteps, seed):

    venv = VecFrameStack(
        make_atari_env(env_id,
                       num_env,
                       seed,
                       wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))

    # Size of states when stored in the memory.
    only_train_r = hps.pop('only_train_r')

    online_r_training = hps.pop('online_train_r') or only_train_r

    r_network_trainer = None
    save_path = hps.pop('save_path')
    r_network_weights_path = hps.pop('r_path')
    '''
    ec_type = 'none' # hps.pop('ec_type')

    venv = CuriosityEnvWrapperFrameStack(
        make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        vec_episodic_memory = None,
        observation_embedding_fn = None,
        exploration_reward = ec_type,
        exploration_reward_min_step = 0,
        nstack = hps.pop('frame_stack'),
        only_train_r = only_train_r
        )
    '''

    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = (env_id == 'SolarisNoFrameskip-v4')
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    log_interval = hps.pop('log_interval')
    nminibatches = hps.pop('nminibatches')

    play = hps.pop('play')
    if play:
        nsteps = 1

    rnd_type = hps.pop('rnd_type')
    div_type = hps.pop('div_type')
    num_agents = hps.pop('num_agents')
    load_ram = hps.pop('load_ram')
    debug = hps.pop('debug')
    rnd_mask_prob = hps.pop('rnd_mask_prob')
    rnd_mask_type = hps.pop('rnd_mask_type')
    indep_rnd = hps.pop('indep_rnd')
    logger.info("indep_rnd:", indep_rnd)
    indep_policy = hps.pop('indep_policy')
    sd_type = hps.pop('sd_type')
    from_scratch = hps.pop('from_scratch')
    use_kl = hps.pop('use_kl')

    save_interval = 100

    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus"),
            num_agents=num_agents,
            rnd_type=rnd_type,
            div_type=div_type,
            indep_rnd=indep_rnd,
            indep_policy=indep_policy,
            sd_type=sd_type,
            rnd_mask_prob=rnd_mask_prob),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        gamma_div=hps.pop('gamma_div'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=nminibatches,
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=5 if debug else 128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        log_interval=log_interval,
        only_train_r=only_train_r,
        rnd_type=rnd_type,
        reset=hps.pop('reset'),
        dynamics_sample=hps.pop('dynamics_sample'),
        save_path=save_path,
        num_agents=num_agents,
        div_type=div_type,
        load_ram=load_ram,
        debug=debug,
        rnd_mask_prob=rnd_mask_prob,
        rnd_mask_type=rnd_mask_type,
        sd_type=sd_type,
        from_scratch=from_scratch,
        use_kl=use_kl,
        indep_rnd=indep_rnd)

    load_path = hps.pop('load_path')
    base_load_path = hps.pop('base_load_path')

    agent.start_interaction([venv])
    if load_path is not None:

        if play:
            agent.load(load_path)
        else:
            #agent.load(load_path)
            #agent.load_help_info(0, load_path)
            #agent.load_help_info(1, load_path)

            #load diversity agent
            #base_agent_idx = 1
            #logger.info("load base  agents weights from {}  agent {}".format(base_load_path, str(base_agent_idx)))
            #agent.load_agent(base_agent_idx, base_load_path)
            #agent.clone_baseline_agent(base_agent_idx)
            #agent.load_help_info(0, dagent_load_path)
            #agent.clone_agent(0)

            #load main agen1
            src_agent_idx = 1

            logger.info("load main agent weights from {} agent {}".format(
                load_path, str(src_agent_idx)))
            agent.load_agent(src_agent_idx, load_path)

            if not indep_rnd:
                rnd_agent_idx = 1
            else:
                rnd_agent_idx = src_agent_idx
            #rnd_agent_idx = 0
            logger.info("load rnd weights from {} agent {}".format(
                load_path, str(rnd_agent_idx)))
            agent.load_rnd(rnd_agent_idx, load_path)
            agent.clone_agent(rnd_agent_idx,
                              rnd=True,
                              policy=False,
                              help_info=False)

            logger.info("load help info from {} agent {}".format(
                load_path, str(src_agent_idx)))
            agent.load_help_info(src_agent_idx, load_path)

            agent.clone_agent(src_agent_idx,
                              rnd=False,
                              policy=True,
                              help_info=True)

            #logger.info("load main agent weights from {} agent {}".format(load_path, str(2)))

            #load_path = '/data/xupeifrom7700_1000/seed1_log0.5_clip-0.5~0.5_3agent_hasint4_2divrew_-1~1/models'
            #agent.load_agent(1, load_path)

            #agent.clone_baseline_agent()
            #if sd_type =='sd':
            #    agent.load_sd("save_dir/models_sd_trained")

        #agent.initialize_discriminator()

    update_ob_stats_from_random_agent = hps.pop(
        'update_ob_stats_from_random_agent')
    if not play:

        if load_path is not None:
            pass  #agent.collect_statistics_from_model()
        else:
            if update_ob_stats_from_random_agent and rnd_type == 'rnd':
                agent.collect_random_statistics(
                    num_timesteps=128 * 5 if debug else 128 * 50)
        assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

        #agent.collect_rnd_info(128*50)
        '''
        if sd_type=='sd':
            agent.train_sd(max_nepoch=300, max_neps=5)
            path = '{}_sd_trained'.format(save_path)
            logger.log("save model:",path)
            agent.save(path)

            return
            #agent.update_diverse_agent(max_nepoch=1000)
            #path = '{}_divupdated'.format(save_path)
            #logger.log("save model:",path)
            #agent.save(path)
        '''
        counter = 0
        while True:
            info = agent.step()

            n_updates = agent.I.stats["n_updates"]
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
                counter += 1

            if info['update'] and save_path is not None and (
                    n_updates % save_interval == 0 or n_updates == 1):
                path = '{}_{}'.format(save_path, str(n_updates))
                logger.log("save model:", path)
                agent.save(path)
                agent.save_help_info(save_path, n_updates)

            if agent.I.stats['tcount'] > num_timesteps:
                path = '{}_{}'.format(save_path, str(n_updates))
                logger.log("save model:", path)
                agent.save(path)
                agent.save_help_info(save_path, n_updates)
                break
        agent.stop_interaction()
    else:
        '''
        check_point_rews_list_path ='{}_rewslist'.format(load_path)
        check_point_rnd_path ='{}_rnd'.format(load_path)
        oracle_rnd = oracle.OracleExplorationRewardForAllEpisodes()
        oracle_rnd.load(check_point_rnd_path)
        #print(oracle_rnd._collected_positions_writer)
        #print(oracle_rnd._collected_positions_reader)

        rews_list = load_rews_list(check_point_rews_list_path)
        print(rews_list)
        '''

        istate = agent.stochpol.initial_state(1)
        #ph_mean, ph_std = agent.stochpol.get_ph_mean_std()

        last_obs, prevrews, ec_rews, news, infos, ram_states, _ = agent.env_get(
            0)
        agent.I.step_count += 1

        flag = False
        show_cam = True

        last_xr = 0

        restore = None
        '''
        #path = 'ram_state_500_7room'
        #path='ram_state_400_6room'
        #path='ram_state_6700' 
        path='ram_state_7700_10room'
        f = open(path,'rb')
        restore = pickle.load(f)
        f.close()
        last_obs[0] = agent.I.venvs[0].restore_full_state_by_idx(restore,0)
        print(last_obs.shape)

        #path = 'ram_state_400_monitor_rews_6room'
        #path = 'ram_state_500_monitor_rews_7room'
        #path='ram_state_6700_monitor_rews'
        path='ram_state_7700_monitor_rews_10room'
        f = open(path,'rb')
        monitor_rews = pickle.load(f)
        f.close()
        
        agent.I.venvs[0].set_cur_monitor_rewards_by_idx(monitor_rews,0)
        '''

        agent_idx = np.asarray([0])
        sample_agent_prob = np.asarray([0.5])

        ph_mean = agent.stochpol.ob_rms_list[0].mean
        ph_std = agent.stochpol.ob_rms_list[0].var**0.5

        buf_ph_mean = np.zeros(
            ([1, 1] + list(agent.stochpol.ob_space.shape[:2]) + [1]),
            np.float32)
        buf_ph_std = np.zeros(
            ([1, 1] + list(agent.stochpol.ob_space.shape[:2]) + [1]),
            np.float32)

        buf_ph_mean[0, 0] = ph_mean
        buf_ph_std[0, 0] = ph_std

        vpreds_ext_list = []

        ep_rews = np.zeros((1))
        divexp_flag = False
        step_count = 0
        stage_prob = True

        last_rew_ob = np.full_like(last_obs, 128)

        clusters = Clusters(1.0)

        #path = '{}_sd_rms'.format(load_path)
        #agent.I.sd_rms.load(path)

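        # Interactive play loop: sample an action from the policy, optionally
        # override it with an action index typed at the prompt, step a single
        # env, print value/diversity diagnostics, and render every frame.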
        while True:

            dict_obs = agent.stochpol.ensure_observation_is_dict(last_obs)

            #acs= np.random.randint(low=0, high=15, size=(1))
            acs, vpreds_int, vpreds_ext, nlps, istate, ent = agent.stochpol.call(
                dict_obs, news, istate, agent_idx[:, None])

            step_acs = acs
            t = ''
            #if show_cam==True:
            t = input("input:")
            if t != '':
                t = int(t)
                if t <= 17:
                    step_acs = [t]

            agent.env_step(0, step_acs)

            obs, prevrews, ec_rews, news, infos, ram_states, monitor_rews = agent.env_get(
                0)

            if news[0] and restore is not None:
                obs[0] = agent.I.venvs[0].restore_full_state_by_idx(restore, 0)
                agent.I.venvs[0].set_cur_monitor_rewards_by_idx(
                    monitor_rews, 0)

            ep_rews = ep_rews + prevrews

            print(ep_rews)

            last_rew_ob[prevrews > 0] = obs[prevrews > 0]

            room = infos[0]['position'][2]
            vpreds_ext_list.append([vpreds_ext, room])
            #print(monitor_rews[0])
            #print(len(monitor_rews[0]))
            #print(infos[0]['open_door_type'])

            stack_obs = np.concatenate([last_obs[:, None], obs[:, None]], 1)

            fd = {}

            fd[agent.stochpol.ph_ob[None]] = stack_obs

            fd.update({
                agent.stochpol.sep_ph_mean: buf_ph_mean,
                agent.stochpol.sep_ph_std: buf_ph_std
            })
            fd[agent.stochpol.ph_agent_idx] = agent_idx[:, None]
            fd[agent.stochpol.sample_agent_prob] = sample_agent_prob[:, None]

            fd[agent.stochpol.last_rew_ob] = last_rew_ob[:, None]
            fd[agent.stochpol.game_score] = ep_rews[:, None]

            fd[agent.stochpol.sd_ph_mean] = agent.I.sd_rms.mean
            fd[agent.stochpol.sd_ph_std] = agent.I.sd_rms.var**0.5

            div_prob = 0

            all_div_prob = tf_util.get_session().run(
                [agent.stochpol.all_div_prob], fd)
            '''
            if prevrews[0] > 0:
                clusters.update(rnd_em,room)
    
                num_clusters = len(clusters._cluster_list)
                for i in range(num_clusters):
                    print("{} {}".format(str(i),list(clusters._room_set[i])))
            '''
            print("vpreds_int: ", vpreds_int, "vpreds_ext:", vpreds_ext,
                  "ent:", ent, "all_div_prob:", all_div_prob, "room:", room,
                  "step_count:", step_count)

            #aaaa = np.asarray(vpreds_ext_list)
            #print(aaaa[-100:])
            '''
            if step_acs[0]==0:
                ram_state = ram_states[0]
                path='ram_state_7700_10room'
                f = open(path,'wb')
                pickle.dump(ram_state,f)
                f.close()

                path='ram_state_7700_monitor_rews_10room'
                f = open(path,'wb')
                pickle.dump(monitor_rews[0],f)
                f.close()
            '''
            '''
            if  restore is None:
                restore = ram_states[0]

            
            if np.random.rand() < 0.1:
                print("restore")
                obs = agent.I.venvs[0].restore_full_state_by_idx(restore,0)
                prevrews = None
                ec_rews = None
                news= True
                infos = {}
                ram_states = ram_states[0]

                #restore = ram_states[0]
            '''

            img = agent.I.venvs[0].render()

            last_obs = obs

            step_count = step_count + 1

            time.sleep(0.04)
def load_test(*, env_id, num_env, hps, num_timesteps, seed, fname):
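    # Restores a saved TensorFlow session, steps the agent until the episode
    # counter exceeds 1, then dumps the recorded observations/actions to
    # obs_acs.pickle.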
    venv = VecFrameStack(
        make_atari_env(env_id,
                       num_env,
                       seed,
                       wrapper_kwargs=dict(),
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1
    venv.record_obs = True
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        obs_save_flag=True)

    tf_util.load_state("saved_states/save1")
    agent.start_interaction([venv])
    counter = 0
    while True:
        info = agent.step()
        if agent.I.stats['epcount'] > 1:
            with open("obs_acs.pickle", 'wb') as f1:
                pickle.dump(agent.obs_rec, f1)
            break
Example #7
def train(*,
          env_id,
          num_env,
          hps,
          num_timesteps,
          seed,
          test=False,
          e_greedy=0,
          **kwargs):
    # import pdb; pdb.set_trace()
    print('_________________________________________________________')
    pprint.pprint(f'hyperparams: {hps}', width=1)
    pprint.pprint(f'additional hyperparams: {kwargs}', width=1)
    print('_________________________________________________________')
    # TODO: this is just for debugging
    # tmp_env = make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(clip_rewards=kwargs['clip_rewards']))
    venv = VecFrameStack(
        make_atari_env(
            env_id,
            num_env,
            seed,
            wrapper_kwargs=dict(clip_rewards=kwargs['clip_rewards']),
            start_index=num_env * MPI.COMM_WORLD.Get_rank(),
            max_episode_steps=hps.pop('max_episode_steps'),
            action_space=kwargs['action_space']), hps.pop('frame_stack'))
    # venv.score_multiple = {'Mario': 500,
    #                        'MontezumaRevengeNoFrameskip-v4': 100,
    #                        'GravitarNoFrameskip-v4': 250,
    #                        'PrivateEyeNoFrameskip-v4': 500,
    #                        'SolarisNoFrameskip-v4': None,
    #                        'VentureNoFrameskip-v4': 200,
    #                        'PitfallNoFrameskip-v4': 100,
    #                        }[env_id]
    venv.score_multiple = 1  # TODO: understand what is score multiple
    venv.record_obs = (env_id == 'SolarisNoFrameskip-v4')
    ob_space = venv.observation_space
    ac_space = venv.action_space
    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]
    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
    )
    saver = tf.train.Saver()
    if test:
        agent.restore_model(saver,
                            kwargs['load_dir'],
                            kwargs['exp_name'],
                            mtype=kwargs['load_mtype'])
    agent.start_interaction([venv])

    import time
    st = time.time()
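    # Reuse cached observation-normalization statistics if a pickle exists;
    # otherwise collect them with a random agent and cache them for later runs.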
    if hps.pop('update_ob_stats_from_random_agent'):
        if os.path.exists('./data/ob_rms.pkl'):
            with open('./data/ob_rms.pkl', 'rb') as handle:
                agent.stochpol.ob_rms.mean, agent.stochpol.ob_rms.var, agent.stochpol.ob_rms.count = pickle.load(
                    handle)

        else:
            ob_rms = agent.collect_random_statistics(
                num_timesteps=128 * 50)  # original number128*50
            with open('./data/ob_rms.pkl', 'wb') as handle:
                pickle.dump([ob_rms.mean, ob_rms.var, ob_rms.count],
                            handle,
                            protocol=2)
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())
    print(f'Time duration {time.time() - st}')

    if not test:
        counter = 1
        while True:
            info = agent.step()
            if info['update']:
                # import pdb; pdb.set_trace()
                logger.logkvs(info['update'])
                logger.dumpkvs()
            if agent.I.stats['tcount'] // 1e7 == counter:
                agent.save_model(saver,
                                 kwargs['save_dir'],
                                 kwargs['exp_name'],
                                 mtype=str(counter * 1e7))
                counter += 1
            if agent.I.stats['tcount'] > num_timesteps:
                break
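    # Test mode: roll out the restored policy for many episodes and pickle the
    # collected rollouts in batches.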
    if test:
        for pkls in range(0, 10):
            print('collecting the', pkls, 'th pickle', ' ' * 20)
            all_rollout = agent.evaluate(env_id=env_id,
                                         seed=seed,
                                         episodes=1000,
                                         save_image=kwargs['save_image'],
                                         e_greedy=e_greedy)
            with open('./data/newworld21_' + str(pkls).zfill(1) + '.pkl',
                      'wb') as handle:
                pickle.dump(all_rollout, handle)
    agent.stop_interaction()
    if not test:
        agent.save_model(saver, kwargs['save_dir'], kwargs['exp_name'])
Example #8
def train(*, env_id, num_env, hps, num_timesteps, seed):
    experiment = os.environ.get('EXPERIMENT_LVL')
    if experiment == 'ego':
        # for the ego experiment we needed a higher intrinsic coefficient
        hps['int_coeff'] = 3.0

    hyperparams = copy(hps)
    hyperparams.update({'seed': seed})
    logger.info("Hyperparameters:")
    logger.info(hyperparams)

    venv = VecFrameStack(
        make_atari_env(env_id,
                       num_env,
                       seed,
                       wrapper_kwargs={},
                       start_index=num_env * MPI.COMM_WORLD.Get_rank(),
                       max_episode_steps=hps.pop('max_episode_steps')),
        hps.pop('frame_stack'))
    venv.score_multiple = {
        'Mario': 500,
        'MontezumaRevengeNoFrameskip-v4': 1,
        'GravitarNoFrameskip-v4': 250,
        'PrivateEyeNoFrameskip-v4': 500,
        'SolarisNoFrameskip-v4': None,
        'VentureNoFrameskip-v4': 200,
        'PitfallNoFrameskip-v4': 100,
    }[env_id]

    venv.record_obs = (env_id == 'SolarisNoFrameskip-v4')
    ob_space = venv.observation_space
    ac_space = venv.action_space

    gamma = hps.pop('gamma')
    policy = {'rnn': CnnGruPolicy, 'cnn': CnnPolicy}[hps.pop('policy')]

    agent = PpoAgent(
        scope='ppo',
        ob_space=ob_space,
        ac_space=ac_space,
        stochpol_fn=functools.partial(
            policy,
            scope='pol',
            ob_space=ob_space,
            ac_space=ac_space,
            update_ob_stats_independently_per_gpu=hps.pop(
                'update_ob_stats_independently_per_gpu'),
            proportion_of_exp_used_for_predictor_update=hps.pop(
                'proportion_of_exp_used_for_predictor_update'),
            dynamics_bonus=hps.pop("dynamics_bonus")),
        gamma=gamma,
        gamma_ext=hps.pop('gamma_ext'),
        lam=hps.pop('lam'),
        nepochs=hps.pop('nepochs'),
        nminibatches=hps.pop('nminibatches'),
        lr=hps.pop('lr'),
        cliprange=0.1,
        nsteps=128,
        ent_coef=0.001,
        max_grad_norm=hps.pop('max_grad_norm'),
        use_news=hps.pop("use_news"),
        comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
        update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
        int_coeff=hps.pop('int_coeff'),
        ext_coeff=hps.pop('ext_coeff'),
        restore_model_path=hps.pop('restore_model_path'))
    agent.start_interaction([venv])
    if hps.pop('update_ob_stats_from_random_agent'):
        agent.collect_random_statistics(num_timesteps=128 * 50)

    save_model = hps.pop('save_model')
    assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())

    #profiler = cProfile.Profile()
    #profiler.enable()
    #tracemalloc.start()
    #prev_snap = tracemalloc.take_snapshot()
    counter = 0
    while True:
        info = agent.step()
        if info['update']:
            logger.logkvs(info['update'])
            logger.dumpkvs()
            counter += 1

            # if (counter % 10) == 0:
            #     snapshot = tracemalloc.take_snapshot()
            #     top_stats = snapshot.compare_to(prev_snap, 'lineno')
            #     for stat in top_stats[:10]:
            #         print(stat)
            #     prev_snap = snapshot
            #     profiler.dump_stats("profile_rnd")

            if (counter % 100) == 0 and save_model:
                agent.save_model(agent.I.step_count)

        if agent.I.stats['tcount'] > num_timesteps:
            break

    agent.stop_interaction()