Ejemplo n.º 1
0
    def _register_from_config_yaml(self, id):
        """Load the task ``id`` from its ``config.yml`` and register it.

        The collection is the part of ``id`` before the last ``'.'``; when
        ``id`` contains no dot, ``self._default_collection`` is used instead.
        The config is read from ``<srcdir>/<id>/config.yml``, where
        ``srcdir`` comes from the collection's entry in ``self._collections``.
        The file's ``spec`` mapping supplies the constructor keyword
        arguments; ``spec['type']`` (default: the collection's
        ``default_task``) names the constructor, resolved via ``utils.load``.
        Any ``id`` key in the spec is dropped because ``id`` is passed
        explicitly. The constructed task is stored in ``self._tasks[id]``.

        Raises:
            error.Error: if no collection can be determined from ``id``.
            error.UnregisteredCollection: if the collection is not known.
            error.UnregisteredEnv: if the config file does not exist.
        """
        if '.' in id:
            collection = id[:id.rfind('.')]
        elif self._default_collection is not None:
            collection = self._default_collection
        else:
            raise error.Error('Could not determine collection from ID, and no default_collection set: {}'.format(id))

        try:
            collection_info = self._collections[collection]
        except KeyError:
            raise error.UnregisteredCollection('Could not load requested id={}. That belongs to the {} collection, but this runtime supports {}. Perhaps that was a typo, or you meant to use a different runtime?'.format(id, collection, ', '.join(self._collections.keys())))
        srcdir = collection_info['srcdir']
        default_task = collection_info['default_task']

        path = os.path.abspath(os.path.join(srcdir, id, 'config.yml'))
        if not os.path.exists(path):
            raise error.UnregisteredEnv('Could not load spec for {}: no such file {}'.format(id, path), path)

        with open(path) as f:
            # yaml.load() without an explicit Loader can construct arbitrary
            # Python objects and is a TypeError on PyYAML >= 6. The config is
            # presumably plain data, so safe_load is sufficient.
            data = yaml.safe_load(f)
        try:
            spec = data['spec']
        except (KeyError, TypeError):
            # TypeError covers an empty file (data is None) or a non-mapping
            # document. reraise presumably re-raises the active exception
            # with this suffix attached — confirm against its definition.
            reraise(suffix='while examining data from {}'.format(path))
        constructor_name = spec.pop('type', default_task)
        constructor = utils.load(constructor_name)
        spec.pop('id', None)

        task = constructor(id=id, **spec)
        self._tasks[id] = task
Ejemplo n.º 2
0
 def build(self, spec):
     """Build a reward object from a ``spec`` mapping.

     ``spec['type']`` selects the class and defaults to ``'score'``. The
     remaining keys are passed as keyword arguments. The caller's mapping is
     copied, not modified.

     Raises:
         error.Error: if the type is not ``'average_score'``,
             ``'negative_score'`` or ``'score'``.
     """
     spec = spec.copy()
     reward_type = spec.pop('type', 'score')
     if reward_type == 'average_score':
         return AverageScore(**spec)
     elif reward_type == 'negative_score':
         return NegativeScore(**spec)
     elif reward_type == 'score':
         return Score(**spec)
     else:
         # Format explicitly: exception constructors store extra positional
         # args as-is rather than %-interpolating them into the message.
         raise error.Error('Invalid reward_type: {}'.format(reward_type))
Ejemplo n.º 3
0
 def load(cls, src_dir, src, spec):
     """Build a transition from ``spec``.

     If ``spec`` is a list, each element is delegated to ``cls.from_spec``
     and the results are wrapped in a ``TransitionList``. Otherwise
     ``spec['type']`` selects the transition class. The remaining keys are
     passed as keyword arguments after ``src``. The caller's mapping is
     copied, not modified.

     Raises:
         KeyError: if a non-list ``spec`` has no ``'type'`` key.
         error.Error: if the type is not a known transition class.
     """
     if isinstance(spec, list):
         transitions = [cls.from_spec(src_dir, src, t) for t in spec]
         return TransitionList(transitions)
     # Copy so popping 'type' does not mutate the caller's spec (consistent
     # with the other spec builders in this module).
     spec = spec.copy()
     transition_type = spec.pop('type')
     if transition_type == 'ClickTransition':
         return ClickTransition(src, **spec)
     elif transition_type == 'KeyPressTransition':
         return KeyPressTransition(src, **spec)
     elif transition_type == 'DragTransition':
         return DragTransition(src, **spec)
     else:
         # 'type' has already been popped, so report the saved value rather
         # than re-reading spec['type'] (which would raise KeyError).
         raise error.Error('Bad transition type: {}'.format(transition_type))
Ejemplo n.º 4
0
    def set_reward_parser(self, env_info):
        """Configure reward parsing and screen subscription for a new env.

        Records ``env_info['env_id']`` and ``env_info['episode_id']`` and
        passes the env id to ``write_env_id``. When the env id is None,
        nothing else happens.

        Otherwise it:

        - looks up the controlplane and gym specs;
        - builds the reward parser, honouring ``no_vexpect`` / ``no_scorer``;
        - pushes the resulting metadata encoding, probe key and screen
          subscription to the unwrapped env's ``diagnostics`` and
          ``vnc_session``.

        Raises:
            error.Error: if the spec's ``metadata_encoding`` has a type other
                than ``'qrcode'`` while the parser reports a subscription.
        """
        self.env_id = env_info['env_id']
        self._episode_id = env_info['episode_id']

        # Let any demonstration code know which env is currently active.
        write_env_id(self.env_id)

        if self.env_id is None:
            return

        self.controlplane_spec = gym_controlplane.spec(self.env_id)
        self.spec = gym.spec(self.env_id)

        # Slow (typically 100-200ms), so avoid calling it more than needed.
        # There is also a suspicion that the scorer TF graph may leak memory.
        self.reward_parser = self.controlplane_spec.build_reward_parser(
            load_vexpect=not self.no_vexpect,
            load_scorer=not self.no_scorer,
        )

        # Screen regions needed by vexpect/scoring; None is passed through.
        raw_regions = self.reward_parser.subscription()
        regions = (None if raw_regions is None
                   else [tuple(region) for region in raw_regions])

        encoding = self.spec.tags.get('metadata_encoding')
        if encoding is not None and regions is not None:
            if encoding['type'] != 'qrcode':
                raise error.Error(
                    'Unsupported metadata encoding type: {}'.format(
                        encoding))
            # Also watch the region that carries the encoded metadata.
            regions.append((encoding['x'], encoding['width'],
                            encoding['y'], encoding['height']))

        # TODO: take this from self.spec.tags.get('action_probe') instead
        # of hard-coding it.
        probe_key = 0x60

        logger.info('Using metadata_encoding=%s probe_key=%s subscription=%s',
                    encoding, probe_key, regions)

        unwrapped = self.env.unwrapped
        unwrapped.diagnostics.update(
            metadata_encoding=encoding,
            probe_key=probe_key,
        )
        # Only subscribe to the parts of the screen that matter; an absent
        # subscription becomes an empty list.
        unwrapped.vnc_session.update(
            name=unwrapped.connection_names[0],  # hack
            subscription=regions or [],
        )
Ejemplo n.º 5
0
 def load(cls, src_dir, state_name, spec):
     """Build a state object from ``spec``.

     ``spec['type']`` selects ``ImageMatchState`` or ``MaskState.load``. The
     remaining keys are passed as keyword arguments together with
     ``src_dir`` and ``state_name``. The caller's mapping is copied, not
     modified.

     Raises:
         KeyError: if ``spec`` has no ``'type'`` key.
         error.Error: if the type is unknown; ``error.Error`` raised during
             construction is propagated unchanged.

     Other exceptions from construction are handed to ``reraise`` with a
     suffix naming the state and spec.
     """
     spec = spec.copy()
     state_type = spec.pop('type')
     try:
         if state_type == 'ImageMatchState':
             return ImageMatchState(src_dir=src_dir,
                                    state_name=state_name,
                                    **spec)
         elif state_type == 'MaskState':
             return MaskState.load(src_dir=src_dir,
                                   state_name=state_name,
                                   **spec)
         else:
             raise error.Error('Bad state type: {}'.format(state_type))
     except error.Error:
         raise
     except Exception:
         # Narrowed from a bare except so KeyboardInterrupt/SystemExit are
         # not routed through reraise.
         reraise(suffix='(while applying: state_name={} spec={})'.format(
             state_name, spec))
Ejemplo n.º 6
0
 def _scorer(self, scorer):
     """Construct the reward scorer described by the ``scorer`` mapping.

     ``scorer['type']`` selects ``reward.DefaultScorer`` or
     ``reward.OCRScorerV0``. For ``DefaultScorer``, digits are looked up
     under ``'<self.id>/digits'`` and ``max_spacing`` is optional. For
     ``OCRScorerV0``, ``median_filter_size`` and ``max_delta`` are optional.
     Optional keys default to None. The built scorer is logged and returned.

     Raises:
         error.Error: if the scorer type is not supported.
     """
     scorer_type = scorer['type']
     if scorer_type == 'DefaultScorer':
         loaded = reward.DefaultScorer(
             digits_path='{}/digits'.format(self.id),
             crop_coords=scorer['crop_coords'],
             digit_color=scorer['digit_color'],
             color_tol=scorer['color_tol'],
             min_overlap=scorer['min_overlap'],
             min_spacing=scorer['min_spacing'],
             max_spacing=scorer.get('max_spacing'),
         )
     elif scorer_type == 'OCRScorerV0':
         loaded = reward.OCRScorerV0(
             model_path=scorer['model_path'],
             crop_coords=scorer['crop_coords'],
             prob_threshold=scorer['prob_threshold'],
             median_filter_size=scorer.get('median_filter_size'),
             max_delta=scorer.get('max_delta'),
         )
     else:
         raise error.Error('Unsupported scorer type: {}'.format(scorer_type))

     logger.info('Loaded scorer: %s', loaded)
     return loaded
Ejemplo n.º 7
0
    def _setup(self):
        """Replay the env's vexpect macro via the ``play_vexpect`` helper.

        Does nothing and returns None in three cases: there is no
        controlplane spec, the spec's ``vexpect_path`` does not exist, or
        ``no_vexpect`` is set. Otherwise it runs ``../bin/play_vexpect``
        (relative to this module) against ``self.vnc_address`` and waits for
        it to exit.

        Returns:
            None if the helper exits with status 0. ``'fail'`` if it exits
            with status 10 (an internal timeout), after calling
            ``self.trigger_reset()``.

        Raises:
            error.Error: for any other exit status.
        """
        if not self.controlplane_spec:
            return None
        if not os.path.exists(self.controlplane_spec.vexpect_path):
            # TODO: DRY this up
            logger.info(
                '[%s] Skipping vexpect initialization since no macro present',
                utils.thread_name())
            return None
        if self.no_vexpect:
            logger.info('[%s] Skipping vexpect initialization as configured',
                        utils.thread_name())
            return None

        helper = os.path.abspath(
            os.path.join(os.path.dirname(__file__), '../bin/play_vexpect'))
        cmd = [helper, '-e', self.controlplane_spec.id,
               '-r', self.vnc_address, '-d']
        logger.info('[%s] Running command: %s', utils.thread_name(),
                    utils.pretty_command(cmd))

        proc = subprocess.Popen(cmd)
        # Register the child only while it runs. Presumably this lets an
        # external cleanup handler kill it if we are interrupted mid-wait
        # (confirm); the entry is removed once communicate() returns.
        manual_subprocess_cleanup[proc.pid] = proc
        proc.communicate()
        del manual_subprocess_cleanup[proc.pid]

        status = proc.returncode
        if status == 10:
            logger.info(
                '[%s] RESET CAUSE: VExpect failed with returncode 10, which means it timed out internally. Going to trigger a reset.',
                utils.thread_name())
            self.trigger_reset()
            return 'fail'
        if status != 0:
            raise error.Error('Bad returncode {} from {}'.format(
                status, utils.pretty_command(cmd)))
        return None
Ejemplo n.º 8
0
def imread(path):
    """Read the image file at ``path`` and return it in RGB channel order.

    Raises:
        error.Error: if ``path`` does not exist or OpenCV cannot decode it.
    """
    import cv2
    if not os.path.exists(path):
        raise error.Error('Image path does not exist: {}'.format(path))
    image = cv2.imread(path)
    # cv2.imread reports unreadable or unsupported files by returning None
    # instead of raising; cvtColor would then fail with an opaque error.
    if image is None:
        raise error.Error('Could not read image: {}'.format(path))
    # OpenCV decodes images in BGR order; convert to RGB for callers.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)