Ejemplo n.º 1
0
def save_image(figure, filepath=None):
    '''
    Render a plotly figure to an image file with the orca CLI, waiting up to 20s.

    Skipped entirely when the PY_ENV environment variable is 'test'. If filepath is None,
    the image is placed in PLOT_FILEDIR and named after the figure's layout title.
    Errors are logged rather than raised; graphs can be regenerated later via retro-analysis.

    @param figure: plotly figure; must be JSON-serializable (it is passed to orca as JSON)
    @param filepath: output image path; resolved through util.smart_path
    '''
    # .get: an unset PY_ENV should not crash with KeyError
    if os.environ.get('PY_ENV') == 'test':
        return
    if filepath is None:
        filepath = f'{PLOT_FILEDIR}/{ps.get(figure, "layout.title")}.png'
    filepath = util.smart_path(filepath)
    dirname, filename = os.path.split(filepath)
    try:
        # argument list instead of a shell string: filenames with spaces or figure JSON
        # containing quotes would otherwise break (or inject into) the command line
        cmd = ['orca', 'graph', '-o', filename, json.dumps(figure)]
        if 'linux' in sys.platform:
            # run under xvfb-run so orca gets a virtual X display
            cmd = ['xvfb-run', '-a', '-s', '-screen 0 1400x900x24', '--'] + cmd
        proc = Popen(cmd, cwd=dirname or None, stderr=DEVNULL, stdout=DEVNULL)
        try:
            proc.communicate(timeout=20)
        except TimeoutExpired:
            proc.kill()
            proc.communicate()  # reap the killed process
    except Exception:
        logger.exception(
            'Please install orca for plotly and run retro-analysis to generate graphs.'
        )
        return
    # only claim success when orca actually exited cleanly (a killed process is non-zero)
    if proc.returncode == 0:
        logger.info(f'Graph saved to {dirname}/{filename}')
    else:
        logger.warning(f'Failed to save graph to {dirname}/{filename} (orca exit code {proc.returncode})')
Ejemplo n.º 2
0
def test_logger(test_str):
    '''Emit test_str once through each logger level, in a fixed order.'''
    for level_name in ('critical', 'debug', 'error', 'exception', 'info', 'warning'):
        log_fn = getattr(logger, level_name)
        log_fn(test_str)
Ejemplo n.º 3
0
def test_logger(test_multiline_str):
    '''Emit a multi-line message through each logger level (test_multiline_str is presumably a pytest fixture).'''
    logger.critical(test_multiline_str)
    logger.debug(test_multiline_str)
    logger.error(test_multiline_str)
    logger.exception(test_multiline_str)
    logger.info(test_multiline_str)
    # `warning`, not the deprecated `warn` alias; matches the single-line logger test
    logger.warning(test_multiline_str)
Ejemplo n.º 4
0
def calc_trial_fitness_df(trial):
    '''
    Calculate the trial fitness df by aggregating from the collected session_data_dict (session_fitness_df's).
    Adds a consistency dimension to fitness vector.

    @param trial: object exposing `session_data_dict` (values are session fitness dfs) and `index`
    @returns: the mean fitness df across all bodies (per-aeb frames are reindexed to FITNESS_COLS first)
    @raises ValueError: if there is no session data to concatenate (e.g. all sessions failed)
    '''
    trial_fitness_data = {}
    try:
        all_session_fitness_df = pd.concat(list(trial.session_data_dict.values()))
    except ValueError:
        logger.exception('Sessions failed, no data to analyze. Check stack trace above')
        # nothing to aggregate: re-raise rather than fall through to an
        # UnboundLocalError on all_session_fitness_df below
        raise
    for aeb in util.get_df_aeb_list(all_session_fitness_df):
        aeb_fitness_df = all_session_fitness_df.loc[:, aeb]
        aeb_fitness_sr = aeb_fitness_df.mean()
        consistency = calc_consistency(aeb_fitness_df)
        # pd.concat instead of Series.append, which was removed in pandas 2.0
        aeb_fitness_sr = pd.concat([aeb_fitness_sr, pd.Series({'consistency': consistency})])
        aeb_fitness_df = pd.DataFrame([aeb_fitness_sr], index=[trial.index])
        aeb_fitness_df = aeb_fitness_df.reindex(FITNESS_COLS, axis=1)
        trial_fitness_data[aeb] = aeb_fitness_df
    # form multi_index df, then take mean across all bodies
    trial_fitness_df = pd.concat(trial_fitness_data, axis=1)
    mean_fitness_df = calc_mean_fitness(trial_fitness_df)
    trial_fitness = calc_fitness(mean_fitness_df)
    logger.info(f'Trial mean fitness: {trial_fitness}\n{mean_fitness_df}')
    return mean_fitness_df
Ejemplo n.º 5
0
def save_image(figure, filepath=None):
    '''
    Write a plotly figure to an image file via pio.write_image (presumably plotly.io).

    Skipped when the PY_ENV environment variable is 'test'. If filepath is None, the image
    goes to PLOT_FILEDIR and is named after the figure's layout title. Errors from the
    image write are logged rather than raised.

    @param figure: plotly figure
    @param filepath: output image path; resolved through util.smart_path
    '''
    # .get: an unset PY_ENV should not crash with KeyError
    if os.environ.get('PY_ENV') == 'test':
        return
    if filepath is None:
        filepath = f'{PLOT_FILEDIR}/{ps.get(figure, "layout.title")}.png'
    filepath = util.smart_path(filepath)
    try:
        pio.write_image(figure, filepath)
        logger.info(f'Graph saved to {filepath}')
    except Exception:
        logger.exception(
            'Failed to generate graph. Fix the issue and run retro-analysis to generate graphs.')
Ejemplo n.º 6
0
def save_image(figure, filepath=None):
    '''
    Launch the orca CLI in the background to render a plotly figure to an image file.

    Skipped when the PY_ENV environment variable is 'test'. If filepath is None, the image
    goes to PLOT_FILEDIR and is named after the figure's layout title. The orca process
    is not waited on, so this returns before the image is written. Failure to launch
    orca is logged rather than raised.

    @param figure: plotly figure; must be JSON-serializable
    @param filepath: output image path; resolved through util.smart_path
    '''
    # .get: an unset PY_ENV should not crash with KeyError
    if os.environ.get('PY_ENV') == 'test':
        return
    if filepath is None:
        filepath = f'{PLOT_FILEDIR}/{ps.get(figure, "layout.title")}.png'
    filepath = util.smart_path(filepath)
    dirname, filename = os.path.split(filepath)
    try:
        # cwd=None when the path has no directory part: Popen rejects cwd=''
        Popen(['orca', 'graph', '--verbose', '-o', filename, json.dumps(figure)], cwd=dirname or None)
    except Exception:
        logger.exception(
            'Please install orca for plotly and run retro-analysis to generate graphs.')
Ejemplo n.º 7
0
def get_ray_results(pending_ids, ray_id_to_config):
    '''
    Helper to wait and get ray results into a new trial_data_dict, or handle ray error.

    Waits on the given ray object ids one at a time, as they become ready. Each result is
    expected to be dict-like with a 'trial_index' entry, which is popped and used as the
    key of the returned dict. A trial whose retrieval or unpacking fails is logged with
    its config and skipped.

    @param pending_ids: list of ray object ids to wait on
    @param ray_id_to_config: map from ray object id to its trial config (used in error logs)
    @returns: dict of trial_index -> trial_data (with 'trial_index' removed)
    '''
    trial_data_dict = {}
    for _ in range(len(pending_ids)):
        ready_ids, pending_ids = ray.wait(pending_ids, num_returns=1)
        ready_id = ready_ids[0]
        try:
            trial_data = ray.get(ready_id)
            trial_index = trial_data.pop('trial_index')
            trial_data_dict[trial_index] = trial_data
        except Exception:  # not a bare except: KeyboardInterrupt/SystemExit must still stop the run
            logger.exception(f'Trial failed: {ray_id_to_config[ready_id]}')
    return trial_data_dict
Ejemplo n.º 8
0
def get_ray_results(pending_ids, ray_id_to_config):
    '''
    Helper to wait and get ray results into a new trial_data_dict, or handle ray error.

    Waits on the given ray object ids one at a time, as they become ready. Each result is
    expected to be dict-like with a 'trial_index' entry, which is popped and used as the
    key of the returned dict. A trial whose retrieval or unpacking fails is logged with
    its config and skipped.

    @param pending_ids: list of ray object ids to wait on
    @param ray_id_to_config: map from ray object id to its trial config (used in error logs)
    @returns: dict of trial_index -> trial_data (with 'trial_index' removed)
    '''
    trial_data_dict = {}
    for _ in range(len(pending_ids)):
        ready_ids, pending_ids = ray.wait(pending_ids, num_returns=1)
        ready_id = ready_ids[0]
        try:
            trial_data = ray.get(ready_id)
            trial_index = trial_data.pop('trial_index')
            trial_data_dict[trial_index] = trial_data
        except Exception:  # not a bare except: KeyboardInterrupt/SystemExit must still stop the run
            logger.exception(f'Trial failed: {ray_id_to_config[ready_id]}')
    return trial_data_dict
Ejemplo n.º 9
0
def check_all():
    '''
    Validate every spec in every `.json` file under SPEC_DIR.

    Each spec is tagged with its own name before being passed to check(). The first
    failure is logged together with its file name, then re-raised.

    @returns: True once all specs have passed
    '''
    json_files = [fname for fname in os.listdir(SPEC_DIR) if fname.endswith('.json')]
    for json_file in json_files:
        specs_by_name = util.read(f'{SPEC_DIR}/{json_file}')
        for name, spec in specs_by_name.items():
            try:
                spec['name'] = name
                check(spec)
            except Exception as err:
                logger.exception(f'spec_file {json_file} fails spec check')
                raise err
    joined_files = ','.join(json_files)
    logger.info(f'Checked all specs from: {joined_files}')
    return True
Ejemplo n.º 10
0
def check_all():
    '''
    Check all spec files, all specs.

    Reads every `.json` file in SPEC_DIR (skipping names starting with '_'), stamps each
    spec with its name and the current git commit SHA, and validates it with check().
    The first failing spec is logged with its file name and the exception re-raised.

    @returns: True if every spec passes
    '''
    spec_files = ps.filter_(os.listdir(SPEC_DIR), lambda f: f.endswith('.json') and not f.startswith('_'))
    # resolved lazily and only once: spawning `git rev-parse` per spec is wasted work,
    # since HEAD is assumed not to change during this run
    git_sha = None
    for spec_file in spec_files:
        spec_dict = util.read(f'{SPEC_DIR}/{spec_file}')
        for spec_name, spec in spec_dict.items():
            try:
                spec['name'] = spec_name
                if git_sha is None:
                    git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode().strip()
                spec['git_SHA'] = git_sha
                check(spec)
            except Exception:
                logger.exception(f'spec_file {spec_file} fails spec check')
                raise
    logger.info(f'Checked all specs from: {ps.join(spec_files, ",")}')
    return True
Ejemplo n.º 11
0
def check_all():
    '''
    Validate every spec in every `.json` file under SPEC_DIR whose name does not start with '_'.

    Each spec gets its name filled in and is expanded with extend_meta_spec before check().
    A failing check is logged with the offending file name and the exception re-raised.

    @returns: True once all specs have passed
    '''
    def _is_spec_file(fname):
        return fname.endswith('.json') and not fname.startswith('_')

    spec_files = ps.filter_(os.listdir(SPEC_DIR), _is_spec_file)
    for spec_file in spec_files:
        specs_by_name = util.read(f'{SPEC_DIR}/{spec_file}')
        for name, raw_spec in specs_by_name.items():
            # fill-in info at runtime
            raw_spec['name'] = name
            full_spec = extend_meta_spec(raw_spec)
            try:
                check(full_spec)
            except Exception as err:
                logger.exception(f'spec_file {spec_file} fails spec check')
                raise err
    logger.info(f'Checked all specs from: {ps.join(spec_files, ",")}')
    return True
Ejemplo n.º 12
0
def save_image(figure, filepath=None):
    '''
    Launch the orca CLI in the background to render a plotly figure to an image file.

    Skipped when the PY_ENV environment variable is 'test'. If filepath is None, the image
    goes to PLOT_FILEDIR and is named after the figure's layout title. On linux, orca is
    run under xvfb-run. The process is not waited on; failure to launch is logged.

    @param figure: plotly figure; must be JSON-serializable
    @param filepath: output image path; resolved through util.smart_path
    '''
    # .get: an unset PY_ENV should not crash with KeyError
    if os.environ.get('PY_ENV') == 'test':
        return
    if filepath is None:
        filepath = f'{PLOT_FILEDIR}/{ps.get(figure, "layout.title")}.png'
    filepath = util.smart_path(filepath)
    dirname, filename = os.path.split(filepath)
    try:
        # argument list instead of a shell string: filenames with spaces or figure JSON
        # containing quotes would otherwise break (or inject into) the command line
        cmd = ['orca', 'graph', '-o', filename, json.dumps(figure)]
        if 'linux' in sys.platform:
            # run under xvfb-run so orca gets a virtual X display
            cmd = ['xvfb-run', '-a', '-s', '-screen 0 1400x900x24', '--'] + cmd
        Popen(cmd, cwd=dirname or None, stderr=DEVNULL, stdout=DEVNULL)
        # NOTE: orca runs in the background; this only confirms the launch succeeded
        logger.info(f'Graph saved to {dirname}/{filename}')
    except Exception:
        logger.exception(
            'Please install orca for plotly and run retro-analysis to generate graphs.')
Ejemplo n.º 13
0
def check(spec):
    '''
    Check a single spec for validity.

    Verifies the spec has at least the top-level keys of SPEC_FORMAT, checks each agent
    and env component plus the body and meta sections against SPEC_FORMAT, then runs
    check_body_spec. On failure the spec name is logged and the exception re-raised.

    @param spec: spec dict
    @returns: True if the spec is valid
    @raises AssertionError: if the spec's top-level keys do not cover SPEC_FORMAT's;
        exceptions from the component checks propagate unchanged
    '''
    # bound before the try so the except clause can always log it, even if spec.get fails
    spec_name = None
    try:
        spec_name = spec.get('name')
        if not set(spec.keys()) >= set(SPEC_FORMAT.keys()):
            # explicit raise instead of `assert`, which is stripped under `python -O`
            raise AssertionError(f'Spec needs to follow spec.SPEC_FORMAT. Given \n {spec_name}: {util.to_json(spec)}')
        for agent_spec in spec['agent']:
            check_comp_spec(agent_spec, SPEC_FORMAT['agent'][0])
        for env_spec in spec['env']:
            check_comp_spec(env_spec, SPEC_FORMAT['env'][0])
        check_comp_spec(spec['body'], SPEC_FORMAT['body'])
        check_comp_spec(spec['meta'], SPEC_FORMAT['meta'])
        check_body_spec(spec)
    except Exception:
        logger.exception(f'spec {spec_name} fails spec check')
        raise
    return True
Ejemplo n.º 14
0
def check(spec):
    '''
    Check a single spec for validity.

    Verifies the spec has at least the top-level keys of SPEC_FORMAT, checks each agent
    and env component plus the body and meta sections against SPEC_FORMAT, then runs
    check_compatibility. On failure the spec name is logged and the exception re-raised.

    @param spec: spec dict
    @returns: True if the spec is valid
    @raises AssertionError: if the spec's top-level keys do not cover SPEC_FORMAT's;
        exceptions from the component checks propagate unchanged
    '''
    # bound before the try so the except clause can always log it, even if spec.get fails
    spec_name = None
    try:
        spec_name = spec.get('name')
        if not set(spec.keys()) >= set(SPEC_FORMAT.keys()):
            # explicit raise instead of `assert`, which is stripped under `python -O`
            raise AssertionError(f'Spec needs to follow spec.SPEC_FORMAT. Given \n {spec_name}: {util.to_json(spec)}')
        for agent_spec in spec['agent']:
            check_comp_spec(agent_spec, SPEC_FORMAT['agent'][0])
        for env_spec in spec['env']:
            check_comp_spec(env_spec, SPEC_FORMAT['env'][0])
        check_comp_spec(spec['body'], SPEC_FORMAT['body'])
        check_comp_spec(spec['meta'], SPEC_FORMAT['meta'])
        # check_body_spec(spec)  # NOTE(review): disabled in this version; confirm check_compatibility covers it
        check_compatibility(spec)
    except Exception:
        logger.exception(f'spec {spec_name} fails spec check')
        raise
    return True
Ejemplo n.º 15
0
def try_register_env(spec):
    '''
    Try to register additional environments (VizDoom, Unity) with OpenAI gym.

    Only the first env entry of the spec is considered; other env names are left alone.
    Any error, including a missing vizdoom cfg_name, is logged rather than raised.

    @param spec: spec dict with an 'env' list whose first entry has a 'name'
    '''
    try:
        first_env = spec['env'][0]
        env_name = first_env['name']
        if env_name == 'vizdoom-v0':
            assert 'cfg_name' in first_env.keys(), 'Environment config name must be defined for vizdoom.'
            vizdoom_kwargs = {'cfg_name': first_env['cfg_name']}
            register(id=env_name,
                     entry_point='slm_lab.env.vizdoom.vizdoom_env:VizDoomEnv',
                     kwargs=vizdoom_kwargs)
        elif env_name.startswith('Unity'):
            # NOTE: do not specify max_episode_steps, will cause shape inconsistency in done
            register(id=env_name,
                     entry_point='slm_lab.env.unity:GymUnityEnv',
                     kwargs={'name': env_name})
    except Exception as err:
        logger.exception(err)
Ejemplo n.º 16
0
def check_all():
    '''
    Check all spec files, all specs.

    Reads every `.json` file in SPEC_DIR (skipping names starting with '_'), stamps each
    spec with its name and the current git commit SHA, and validates it with check().
    The first failing spec is logged with its file name and the exception re-raised.

    @returns: True if every spec passes
    '''
    spec_files = ps.filter_(
        os.listdir(SPEC_DIR),
        lambda f: f.endswith('.json') and not f.startswith('_'))
    # resolved lazily and only once: spawning `git rev-parse` per spec is wasted work,
    # since HEAD is assumed not to change during this run
    git_sha = None
    for spec_file in spec_files:
        spec_dict = util.read(f'{SPEC_DIR}/{spec_file}')
        for spec_name, spec in spec_dict.items():
            try:
                spec['name'] = spec_name
                if git_sha is None:
                    git_sha = subprocess.check_output(
                        ['git', 'rev-parse', 'HEAD']).decode().strip()
                spec['git_SHA'] = git_sha
                check(spec)
            except Exception:
                logger.exception(f'spec_file {spec_file} fails spec check')
                raise
    logger.info(f'Checked all specs from: {ps.join(spec_files, ",")}')
    return True
Ejemplo n.º 17
0
def subproc_worker(pipe, parent_pipe, env_fn_wrapper, obs_bufs, obs_shapes,
                   obs_dtypes, keys):
    '''
    Control a single environment instance using IPC and shared memory. Used by ShmemVecEnv.

    Serves commands received on `pipe` until 'close': 'reset', 'step' and 'render' are
    handled; any other command raises RuntimeError. Observations are not sent over the
    pipe. They are copied into the shared buffers `obs_bufs`, and the pipe carries the
    return value of _write_obs (None) in their place. The env is always closed on exit.

    @param pipe: this worker's end of the pipe; receives (cmd, data) tuples and sends replies
    @param parent_pipe: the other end of the pipe; closed here, the worker never uses it
    @param env_fn_wrapper: object whose `.x` is a zero-argument env factory
    @param obs_bufs: dict key -> shared buffer exposing get_obj() (presumably multiprocessing.Array — confirm)
    @param obs_shapes: dict key -> shape used to view each buffer
    @param obs_dtypes: dict key -> numpy dtype used to view each buffer
    @param keys: observation keys to copy into the buffers
    '''
    def _write_obs(maybe_dict_obs):
        # Copy one observation into the shared buffers in place; returns None.
        # obs_to_dict presumably wraps a non-dict obs into a dict keyed like `keys` — see its definition
        flatdict = obs_to_dict(maybe_dict_obs)
        for k in keys:
            dst = obs_bufs[k].get_obj()
            # np.frombuffer gives a view (no copy) of the shared memory, so copyto writes into it
            dst_np = np.frombuffer(dst,
                                   dtype=obs_dtypes[k]).reshape(obs_shapes[k])
            np.copyto(dst_np, flatdict[k])

    env = env_fn_wrapper.x()
    parent_pipe.close()
    try:
        while True:
            cmd, data = pipe.recv()
            if cmd == 'reset':
                # obs goes to shared memory; the pipe only signals completion (sends None)
                pipe.send(_write_obs(env.reset()))
            elif cmd == 'step':
                obs, reward, done, info = env.step(data)
                if done:
                    # auto-reset: the reset observation replaces the terminal one in the buffer
                    obs = env.reset()
                pipe.send((_write_obs(obs), reward, done, info))
            elif cmd == 'render':
                pipe.send(env.render(mode='rgb_array'))
            elif cmd == 'close':
                # acknowledge, then leave the loop; env.close() runs in the finally below
                pipe.send(None)
                break
            else:
                raise RuntimeError(f'Got unrecognized cmd {cmd}')
    except KeyboardInterrupt:
        logger.exception('ShmemVecEnv worker: got KeyboardInterrupt')
    finally:
        env.close()