def _register_from_config_yaml(self, id):
    """Load the task spec for ``id`` from its ``config.yml`` and register it.

    The collection is the part of ``id`` before its last '.'. If ``id``
    contains no '.', ``self._default_collection`` is used instead. The
    collection's ``srcdir`` and ``default_task`` come from
    ``self._collections``. The config file is read from
    ``<srcdir>/<id>/config.yml``.

    Inside that file, the mapping under ``spec`` is used as follows:
    - Its optional ``type`` entry names the task constructor, loaded via
      ``utils.load``; it defaults to the collection's ``default_task``.
    - Its ``id`` entry, if any, is discarded.
    - Every other entry is passed to the constructor as a keyword argument,
      together with ``id=id``.

    The constructed task is stored in ``self._tasks[id]``.

    Raises:
        error.Error: ``id`` has no '.' and no default collection is set.
        error.UnregisteredCollection: the collection is not registered.
        error.UnregisteredEnv: the ``config.yml`` file does not exist.
    """
    if '.' in id:
        collection = id[:id.rfind('.')]
    elif self._default_collection is not None:
        collection = self._default_collection
    else:
        raise error.Error(
            'Could not determine collection from ID, and no default_collection set: {}'.format(id))

    try:
        collection_info = self._collections[collection]
    except KeyError:
        raise error.UnregisteredCollection(
            'Could not load requested id={}. That belongs to the {} collection, but this runtime supports {}. Perhaps that was a typo, or you meant to use a different runtime?'.format(
                id, collection, ', '.join(self._collections.keys())))

    srcdir = collection_info['srcdir']
    default_task = collection_info['default_task']

    path = os.path.abspath(os.path.join(srcdir, id, 'config.yml'))
    if not os.path.exists(path):
        raise error.UnregisteredEnv(
            'Could not load spec for {}: no such file {}'.format(id, path), path)

    with open(path) as f:
        # safe_load: plain yaml.load() without a Loader is deprecated (and a
        # TypeError in PyYAML 6) and can construct arbitrary Python objects.
        data = yaml.safe_load(f)

    try:
        spec = data['spec']
    except (KeyError, TypeError):
        # TypeError covers an empty file (data is None) or a non-mapping
        # document, so those errors also get the file path attached.
        reraise(suffix='while examining data from {}'.format(path))

    constructor_name = spec.pop('type', default_task)
    constructor = utils.load(constructor_name)
    # The task id is supplied explicitly below; drop any copy in the config.
    spec.pop('id', None)
    task = constructor(id=id, **spec)
    self._tasks[id] = task
def build(self, spec):
    """Construct a reward object from a config ``spec`` mapping.

    ``spec['type']`` selects the class and defaults to ``'score'``:
    ``'average_score'`` -> AverageScore, ``'negative_score'`` -> NegativeScore,
    ``'score'`` -> Score. All remaining entries are passed as keyword
    arguments. The caller's ``spec`` is not modified.

    Raises:
        error.Error: if the type is not one of the names above.
    """
    spec = spec.copy()
    reward_type = spec.pop('type', 'score')
    if reward_type == 'average_score':
        return AverageScore(**spec)
    elif reward_type == 'negative_score':
        return NegativeScore(**spec)
    elif reward_type == 'score':
        return Score(**spec)
    # Format the message ourselves: exceptions do not %-interpolate their args.
    raise error.Error('Invalid reward_type: {}'.format(reward_type))
def load(cls, src_dir, src, spec):
    """Build a transition (or a TransitionList) from its config ``spec``.

    If ``spec`` is a list, each item is converted via ``cls.from_spec``
    and the results are wrapped in a TransitionList. Otherwise ``spec`` is a
    mapping whose ``type`` entry selects ClickTransition,
    KeyPressTransition or DragTransition. The remaining entries are passed
    as keyword arguments after ``src``. The caller's mapping is not modified.

    Raises:
        error.Error: if ``type`` is not a known transition type.
    """
    if isinstance(spec, list):
        # NOTE(review): ``from_spec`` is not visible here; confirm it exists
        # on this class (recursing via ``cls.load`` may have been intended).
        transitions = [cls.from_spec(src_dir, src, t) for t in spec]
        return TransitionList(transitions)

    # Copy so the caller's config dict keeps its 'type' key (matches the
    # state loader, which also copies before popping).
    spec = spec.copy()
    transition_type = spec.pop('type')
    if transition_type == 'ClickTransition':
        return ClickTransition(src, **spec)
    elif transition_type == 'KeyPressTransition':
        return KeyPressTransition(src, **spec)
    elif transition_type == 'DragTransition':
        return DragTransition(src, **spec)
    # Use the popped value: spec['type'] no longer exists at this point.
    raise error.Error('Bad transition type: {}'.format(transition_type))
def set_reward_parser(self, env_info):
    """Switch to the env described by ``env_info`` and rebuild the reward parser.

    Records ``env_info['env_id']`` and ``env_info['episode_id']`` and passes
    the env id to ``write_env_id``. If the env id is None, the method stops
    there.

    Otherwise it:
    - looks up the controlplane and gym specs;
    - builds a new reward parser, honoring ``self.no_vexpect`` and
      ``self.no_scorer``;
    - computes the screen subscription;
    - pushes the metadata-encoding/probe-key settings to
      ``self.env.unwrapped.diagnostics`` and the subscription to
      ``self.env.unwrapped.vnc_session``.

    Args:
        env_info: mapping with at least 'env_id' and 'episode_id' keys.

    Raises:
        error.Error: if the env's 'metadata_encoding' tag has a type other
            than 'qrcode' while the reward parser returned a subscription.
    """
    self.env_id = env_info['env_id']
    self._episode_id = env_info['episode_id']
    # If in demo mode, let the demonstration code know what env is
    # active. This runs even when env_id is None.
    write_env_id(self.env_id)
    if self.env_id is None:
        return
    self.controlplane_spec = gym_controlplane.spec(self.env_id)
    self.spec = gym.spec(self.env_id)
    # This is quite slow (usually 100-200ms) so just be careful
    # about calling it too much. We also have some suspicions that
    # the scorer TF graph may leak memory but haven't needed to
    # investigate.
    self.reward_parser = self.controlplane_spec.build_reward_parser(
        load_vexpect=not self.no_vexpect, load_scorer=not self.no_scorer)
    # All the pixels needed for vexpect/scoring. Each region is normalized
    # to a tuple; None is kept as None.
    subscription = self.reward_parser.subscription()
    if subscription is not None:
        subscription = [tuple(sub) for sub in subscription]
    metadata_encoding = self.spec.tags.get('metadata_encoding')
    # The metadata region is only appended (and the encoding type only
    # validated) when there is a subscription list to extend.
    if metadata_encoding is not None and subscription is not None:
        if metadata_encoding['type'] == 'qrcode':
            # Region tuple order is (x, width, y, height).
            # NOTE(review): presumably the order vnc_session expects; confirm.
            subscription += [
                (metadata_encoding['x'], metadata_encoding['width'],
                 metadata_encoding['y'], metadata_encoding['height'])
            ]
        else:
            raise error.Error(
                'Unsupported metadata encoding type: {}'.format(
                    metadata_encoding))
    # Should fix this up and abstract. The probe key is hard-coded to 0x60;
    # the commented-out line shows the intended source (the env's
    # 'action_probe' tag).
    # probe_key = self.spec.tags.get('action_probe')
    probe_key = 0x60
    logger.info('Using metadata_encoding=%s probe_key=%s subscription=%s',
                metadata_encoding, probe_key, subscription)
    # Hand the metadata-encoding settings and probe key to the env's
    # diagnostics.
    self.env.unwrapped.diagnostics.update(
        metadata_encoding=metadata_encoding,
        probe_key=probe_key,
    )
    # Just subscribe to the parts of the screen we're going to care about.
    # A None (or empty) subscription is sent as [].
    self.env.unwrapped.vnc_session.update(
        name=self.env.unwrapped.connection_names[0],  # hack: only the first connection is used
        subscription=subscription or [],
    )
def load(cls, src_dir, state_name, spec):
    """Build a vexpect state object from its config ``spec`` mapping.

    ``spec['type']`` selects the implementation:
    - ``'ImageMatchState'`` constructs an ImageMatchState directly.
    - ``'MaskState'`` delegates to ``MaskState.load``.

    The remaining entries are passed as keyword arguments alongside
    ``src_dir`` and ``state_name``. The caller's mapping is not modified.

    Raises:
        error.Error: for an unknown state type. Any ``error.Error`` raised
            during construction also propagates unchanged.
        Other exceptions from construction are re-raised via ``reraise``
        with the state name and spec appended to the message.
    """
    spec = spec.copy()
    state_type = spec.pop('type')
    try:
        if state_type == 'ImageMatchState':
            return ImageMatchState(src_dir=src_dir, state_name=state_name, **spec)
        elif state_type == 'MaskState':
            return MaskState.load(src_dir=src_dir, state_name=state_name, **spec)
        else:
            raise error.Error('Bad state type: {}'.format(state_type))
    except error.Error:
        raise
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit are
        # not caught and rewrapped as ordinary errors.
        reraise(suffix='(while applying: state_name={} spec={})'.format(
            state_name, spec))
def _scorer(self, scorer):
    """Instantiate the reward scorer described by the ``scorer`` config.

    Supported ``scorer['type']`` values:

    * ``'DefaultScorer'``: builds ``reward.DefaultScorer`` with
      ``digits_path='<self.id>/digits'``. Required keys: ``crop_coords``,
      ``digit_color``, ``color_tol``, ``min_overlap``, ``min_spacing``.
      Optional key: ``max_spacing``.
    * ``'OCRScorerV0'``: builds ``reward.OCRScorerV0``. Required keys:
      ``model_path``, ``crop_coords``, ``prob_threshold``. Optional keys:
      ``median_filter_size``, ``max_delta``.

    Optional keys absent from the config are passed as None.

    Returns:
        The constructed scorer, which is also logged at INFO level.

    Raises:
        error.Error: for any other ``scorer['type']``.
    """
    kind = scorer['type']
    if kind == 'DefaultScorer':
        options = dict(
            digits_path='{}/digits'.format(self.id),
            crop_coords=scorer['crop_coords'],
            digit_color=scorer['digit_color'],
            color_tol=scorer['color_tol'],
            min_overlap=scorer['min_overlap'],
            min_spacing=scorer['min_spacing'],
            max_spacing=scorer.get('max_spacing'),
        )
        loaded = reward.DefaultScorer(**options)
    elif kind == 'OCRScorerV0':
        options = dict(
            model_path=scorer['model_path'],
            crop_coords=scorer['crop_coords'],
            prob_threshold=scorer['prob_threshold'],
            median_filter_size=scorer.get('median_filter_size'),
            max_delta=scorer.get('max_delta'),
        )
        loaded = reward.OCRScorerV0(**options)
    else:
        raise error.Error('Unsupported scorer type: {}'.format(kind))

    logger.info('Loaded scorer: %s', loaded)
    return loaded
def _setup(self): if not self.controlplane_spec: return elif not os.path.exists(self.controlplane_spec.vexpect_path): # TODO: DRY this up logger.info( '[%s] Skipping vexpect initialization since no macro present', utils.thread_name()) return elif self.no_vexpect: logger.info('[%s] Skipping vexpect initialization as configured', utils.thread_name()) return cmd = [ os.path.abspath( os.path.join(os.path.dirname(__file__), '../bin/play_vexpect')), '-e', self.controlplane_spec.id, '-r', self.vnc_address, '-d' ] logger.info('[%s] Running command: %s', utils.thread_name(), utils.pretty_command(cmd)) proc = subprocess.Popen(cmd) manual_subprocess_cleanup[proc.pid] = proc proc.communicate() del manual_subprocess_cleanup[proc.pid] if proc.returncode == 0: return elif proc.returncode == 10: logger.info( '[%s] RESET CAUSE: VExpect failed with returncode 10, which means it timed out internally. Going to trigger a reset.', utils.thread_name()) self.trigger_reset() return 'fail' else: raise error.Error('Bad returncode {} from {}'.format( proc.returncode, utils.pretty_command(cmd)))
def imread(path):
    """Read the image at ``path`` and return it in RGB channel order.

    OpenCV loads images as BGR; the result is converted to RGB.

    Raises:
        error.Error: if ``path`` does not exist, or OpenCV cannot decode it.
    """
    import cv2
    if not os.path.exists(path):
        raise error.Error('Image path does not exist: {}'.format(path))
    # cv2.imread does not raise on unreadable or unsupported files; it
    # returns None. Check that here rather than let cvtColor fail with an
    # opaque error.
    image = cv2.imread(path)
    if image is None:
        raise error.Error('Could not read image: {}'.format(path))
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)