def __init__(self, rom_path, seed, step_limit=None):
    """Record the game configuration without starting an emulator.

    Args:
        rom_path: Path to the game ROM; its bindings are loaded immediately.
        seed: Seed stored on the instance for later use.
        step_limit: Optional cap on the number of steps per episode.
    """
    self.rom_path = rom_path
    bindings = load_bindings(rom_path)
    self.bindings = bindings
    self.act_gen = TemplateActionGenerator(bindings)
    self.seed, self.step_limit = seed, step_limit
    self.steps = 0
    # Emulator, Redis connection and reverse vocabulary are not built here;
    # they stay None until some later setup step assigns them.
    self.env = self.conn = self.vocab_rev = None
def __init__(self, params):
    """Build the KG-A2C training setup from the ``params`` dict.

    Configures logging, loads the game bindings and SentencePiece model, wraps
    a KGA2CEnv in a VecEnv, builds the KGA2C model on the GPU (replaced by the
    preloaded model when ``params['preload_weights']`` is set), and creates
    the Adam optimizer plus the three loss functions.
    """
    configure_logger(params['output_dir'])
    log('Parameters {}'.format(params))
    self.params = params

    rom_file = params['rom_file_path']
    self.binding = load_bindings(rom_file)
    self.max_word_length = self.binding['max_word_length']

    tokenizer = spm.SentencePieceProcessor()
    self.sp = tokenizer
    tokenizer.Load(params['spm_file'])

    kg_env = KGA2CEnv(
        rom_file,
        params['seed'],
        tokenizer,
        params['tsv_file'],
        step_limit=params['reset_steps'],
        stuck_steps=params['stuck_steps'],
        gat=params['gat'],
    )
    self.vec_env = VecEnv(params['batch_size'], kg_env, params['openie_path'])
    self.template_generator = TemplateActionGenerator(self.binding)

    # A standalone emulator used only to read the game's action vocabulary.
    vocab_env = FrotzEnv(rom_file)
    self.vocab_act, self.vocab_act_rev = load_vocab(vocab_env)

    self.model = KGA2C(
        params,
        self.template_generator.templates,
        self.max_word_length,
        self.vocab_act,
        self.vocab_act_rev,
        len(tokenizer),
        gat=params['gat'],
    ).cuda()
    self.batch_size = params['batch_size']

    weights_path = params['preload_weights']
    if weights_path:
        # The checkpoint's stored model object replaces the freshly built one.
        self.model = torch.load(weights_path)['model']

    self.optimizer = optim.Adam(self.model.parameters(), lr=params['lr'])
    self.loss_fn1, self.loss_fn2, self.loss_fn3 = (
        nn.BCELoss(),
        nn.BCEWithLogitsLoss(),
        nn.MSELoss(),
    )
def test_valid_action_identification():
    """Check that Jericho finds the expected valid actions at the start of 905."""
    rom_path = pjoin(DATA_PATH, '905.z5')
    env = jericho.FrotzEnv(rom_path)
    act_gen = TemplateActionGenerator(jericho.load_bindings(rom_path))
    _obs, _info = env.reset()

    # The interactive objects are fixed by hand rather than detected, e.g. via
    # env.identify_interactive_objects(use_object_tree=True).
    candidate_actions = act_gen.generate_actions(['phone', 'keys', 'wallet'])
    valid = env.find_valid_actions(candidate_actions)

    expected_actions = (
        'take wallet',
        'open wallet',
        'take keys',
        'get up',
        'take phone',
    )
    for expected in expected_actions:
        assert expected in valid
def create(self):
    """Build the emulator, action generator, reverse vocabulary and Redis link.

    NOTE: ``flushdb()`` deletes every key in Redis database 0 on
    localhost:6379, not just keys written by this object.
    """
    emulator = FrotzEnv(self.rom_path, self.seed)
    self.env = emulator

    bindings = emulator.bindings
    self.bindings = bindings
    self.act_gen = TemplateActionGenerator(bindings)
    self.vocab_rev = load_vocab_rev(emulator)

    connection = redis.Redis(host='localhost', port=6379, db=0)
    self.conn = connection
    connection.flushdb()
def __init__(self, params):
    """Set up a Q*BERT training run from the ``params`` dict.

    Seeds torch/numpy/random, configures logging, makes sure the checkpoint
    directory exists, builds the batched QBERTEnv environment, the QBERT model
    on the GPU (replaced by the preloaded model when
    ``params['preload_weights']`` is set), its Adam optimizer and losses.
    """
    torch.manual_seed(params['seed'])
    np.random.seed(params['seed'])
    random.seed(params['seed'])
    configure_logger(params['output_dir'])
    log('Parameters {}'.format(params))
    self.params = params

    self.chkpt_path = os.path.dirname(self.params['checkpoint_path'])
    # dirname() is '' when checkpoint_path is a bare file name (checkpoint in
    # the CWD); os.mkdir('') would raise. makedirs also creates any missing
    # parent directories, which os.mkdir cannot.
    if self.chkpt_path and not os.path.exists(self.chkpt_path):
        os.makedirs(self.chkpt_path, exist_ok=True)

    self.binding = load_bindings(params['rom_file_path'])
    self.max_word_length = self.binding['max_word_length']
    self.sp = spm.SentencePieceProcessor()
    self.sp.Load(params['spm_file'])

    self.askbert = params['extraction']
    kg_env = QBERTEnv(params['rom_file_path'], params['seed'], self.sp,
                      params['tsv_file'], params['attr_file'],
                      step_limit=params['reset_steps'],
                      stuck_steps=params['stuck_steps'],
                      gat=params['gat'],
                      askbert=self.askbert,
                      clear_kg=params['clear_kg_on_reset'])
    self.vec_env = VecEnv(params['batch_size'], kg_env, params['openie_path'],
                          params['redis_db_path'], params['buffer_size'],
                          params['extraction'], params['training_type'],
                          params['clear_kg_on_reset'])
    self.template_generator = TemplateActionGenerator(self.binding)

    # A standalone emulator used for the max score, the initial (reload)
    # state and the action vocabulary.
    env = FrotzEnv(params['rom_file_path'])
    self.max_game_score = env.get_max_score()
    self.cur_reload_state = env.get_state()
    self.vocab_act, self.vocab_act_rev = load_vocab(env)

    self.model = QBERT(params, self.template_generator.templates,
                       self.max_word_length, self.vocab_act,
                       self.vocab_act_rev, len(self.sp),
                       gat=self.params['gat']).cuda()
    self.batch_size = params['batch_size']
    if params['preload_weights']:
        self.model = torch.load(self.params['preload_weights'])['model']
    self.optimizer = optim.Adam(self.model.parameters(), lr=params['lr'])
    self.loss_fn1 = nn.BCELoss()
    self.loss_fn2 = nn.BCEWithLogitsLoss()
    self.loss_fn3 = nn.MSELoss()

    self.chained_logger = params['chained_logger']
    self.total_steps = 0
def __init__(self, params, args):
    """Set up the KG-A2C trainer, optionally with a COMET knowledge extractor.

    Args:
        params: Configuration dict (ROM path, tokenizer, env and model options).
        args: Passed to ``CometHelper`` when ``params['use_cs']`` equals True,
            and stored as ``self.args``.
    """
    configure_logger(params['output_dir'])
    log('Parameters {}'.format(params))
    self.params = params

    rom_file = params['rom_file_path']
    self.binding = load_bindings(rom_file)
    self.max_word_length = self.binding['max_word_length']
    self.sp = spm.SentencePieceProcessor()
    self.sp.Load(params['spm_file'])

    self.use_cs = self.params['use_cs']
    # Equality with True (not plain truthiness) so that e.g. a string value
    # does not switch COMET on. Note: kg_extract is only set on this path.
    if self.use_cs == True:  # noqa: E712
        print("Using COMET")
        self.kg_extract = CometHelper(args)

    kg_env = KGA2CEnv(
        rom_file,
        params['seed'],
        self.sp,
        params['tsv_file'],
        step_limit=params['reset_steps'],
        stuck_steps=params['stuck_steps'],
        gat=params['gat'],
    )
    self.vec_env = VecEnv(params['batch_size'], kg_env, params['openie_path'])
    self.template_generator = TemplateActionGenerator(self.binding)

    vocab_env = FrotzEnv(rom_file)
    self.vocab_act, self.vocab_act_rev = load_vocab(vocab_env)

    # Select the CUDA device before the model is constructed.
    a2c_device = int(self.params['device_a2c'])
    torch.cuda.set_device(a2c_device)
    self.model = KGA2C(
        params,
        self.template_generator.templates,
        self.max_word_length,
        self.vocab_act,
        self.vocab_act_rev,
        len(self.sp),
        a2c_device=a2c_device,
        gat=self.params['gat'],
    )
    self.batch_size = params['batch_size']
    if params['preload_weights']:
        self.model = torch.load(self.params['preload_weights'])['model']

    self.optimizer = optim.Adam(self.model.parameters(), lr=params['lr'])
    self.loss_fn1, self.loss_fn2, self.loss_fn3 = (
        nn.BCELoss(),
        nn.BCEWithLogitsLoss(),
        nn.MSELoss(),
    )
    self.args = args
def __init__(self, rom_path, seed, step_limit=None, env_num=8):
    """Create the main emulator, template tables, vocabulary and worker pool.

    Args:
        rom_path: Path to the game ROM.
        seed: Accepted for interface compatibility; see the NOTE below.
        step_limit: Optional cap on the number of steps per episode.
        env_num: Number of parallel worker environments to start.
    """
    self.rom_path = rom_path
    self.step_limit = step_limit
    self.bindings = load_bindings(rom_path)
    # NOTE(review): the caller's `seed` is discarded; the seed recorded in the
    # game bindings is always used instead.
    self.seed = self.bindings['seed']

    # Extra templates on top of the bindings' own list. Template ids follow the
    # sorted order of the combined list, so editing this shifts the ids.
    self.additional_templates = ['land']
    self.act_gen = TemplateActionGenerator(self.bindings)
    self.act_gen.templates = sorted(
        set(self.act_gen.templates + self.additional_templates))
    self.id2template = self.template2id = None
    self._compute_template()

    self.env = FrotzEnv(self.rom_path, self.seed)

    self.env_num = env_num
    self.ps = self.envs = self.remotes = self.work_remotes = None
    self.parallel = False
    self._init_parallel_workers()

    self.steps = 0
    self.max_word_len = self.bindings['max_word_length']
    self.word2id = self.id2word = self.noun_words = None
    self._compute_vocab_act()
    self.state2valid_acts = {}
def __init__(self, args):
    """Set up a TDQN agent.

    Configures logging and a wandb run, loads the tokenizer, bindings and action
    vocabulary, creates the replay buffer, online and target TDQN networks, and
    the optimizer.

    Raises:
        ValueError: If ``args.replay_buffer_type`` is neither 'priority' nor
            'standard'. Without this check ``self.replay_buffer`` would be left
            unset and the run would fail much later.
    """
    # Validate before any side effects (logger, wandb run, GPU allocation).
    if args.replay_buffer_type not in ('priority', 'standard'):
        raise ValueError(
            "Unknown replay_buffer_type {!r}; expected 'priority' or "
            "'standard'".format(args.replay_buffer_type))

    configure_logger(args.output_dir)
    log(args)
    self.args = args
    self.log_freq = args.log_freq
    self.update_freq = args.update_freq_td
    self.update_freq_tar = args.update_freq_tar
    self.filename = 'ptdqn' + args.rom_path + str(args.run_number)
    wandb.init(project="my-project", name=self.filename)

    self.sp = spm.SentencePieceProcessor()
    self.sp.Load(args.spm_path)
    self.binding = jericho.load_bindings(args.rom_path)
    self.vocab_act, self.vocab_act_rev = self.load_vocab_act(args.rom_path)
    vocab_size = len(self.sp)
    vocab_size_act = len(self.vocab_act)
    self.template_generator = TemplateActionGenerator(self.binding)
    self.template_size = len(self.template_generator.templates)

    if args.replay_buffer_type == 'priority':
        self.replay_buffer = PriorityReplayBuffer(int(args.replay_buffer_size))
    else:  # 'standard' (validated above)
        self.replay_buffer = ReplayBuffer(int(args.replay_buffer_size))

    self.action_size = self.template_size
    self.action_parameter_size = vocab_size_act
    self.model = TDQN(args, self.action_size, self.action_parameter_size,
                      self.template_size, vocab_size, vocab_size_act).cuda()
    self.target_model = TDQN(args, self.action_size,
                             self.action_parameter_size, self.template_size,
                             vocab_size, vocab_size_act).cuda()
    self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)

    self.num_steps = args.steps
    self.batch_size = args.batch_size
    self.gamma = args.gamma
    self.rho = args.rho
    self.vocab_size_act = vocab_size_act
    self.bce_loss = nn.BCELoss()
def __init__(self, params):
    """Build the KG-A2C training setup, printing a banner before each stage.

    Stages: logger; bindings and SentencePiece model; batched KGA2CEnv;
    action vocabulary and template generator; the model (replaced by the
    preloaded model when ``params['preload_weights']`` is set); optimizer and
    loss functions.
    """
    print("----- Initiating ----- ")

    print("----- step 1 configure logger")
    configure_logger(params['output_dir'])
    log('Parameters {}'.format(params))
    self.params = params

    print("----- step 2 load pre-collected things")
    rom_file = params['rom_file_path']
    self.binding = load_bindings(rom_file)
    self.max_word_length = self.binding['max_word_length']
    self.sp = spm.SentencePieceProcessor()
    self.sp.Load(params['spm_file'])

    print("----- step 3 build KGA2CEnv")
    kg_env = KGA2CEnv(
        rom_file,
        params['seed'],
        self.sp,
        params['tsv_file'],
        step_limit=params['reset_steps'],
        stuck_steps=params['stuck_steps'],
        gat=params['gat'],
    )
    self.vec_env = VecEnv(params['batch_size'], kg_env, params['openie_path'])

    print("----- step 4 build FrotzEnv and templace generator")
    vocab_env = FrotzEnv(rom_file)
    self.vocab_act, self.vocab_act_rev = load_vocab(vocab_env)
    self.template_generator = TemplateActionGenerator(self.binding)

    print("----- step 5 build kga2c model")
    self.model = KGA2C(
        params,
        self.template_generator.templates,
        self.max_word_length,
        self.vocab_act,
        self.vocab_act_rev,
        len(self.sp),
        gat=params['gat'],
    ).cuda()
    weights_path = params['preload_weights']
    if not weights_path:
        print("train from scratch")
    else:
        print("load pretrained")
        self.model = torch.load(weights_path)['model']

    print("----- step 6 set training parameters")
    self.batch_size = params['batch_size']
    self.optimizer = optim.Adam(self.model.parameters(), lr=params['lr'])
    self.loss_fn1, self.loss_fn2, self.loss_fn3 = (
        nn.BCELoss(),
        nn.BCEWithLogitsLoss(),
        nn.MSELoss(),
    )
    print("----- Init finished! ----- ")
def __init__(self, args):
    """Set up the random-action baseline.

    Starts a wandb run, loads the SentencePiece model, game bindings, action
    vocabulary and template generator, and records the step budget.
    """
    self.args = args
    self.log_freq = args.log_freq
    self.update_freq = args.update_freq_td
    self.update_freq_tar = args.update_freq_tar
    self.filename = 'random' + args.rom_path + str(args.run_number)
    wandb.init(project="my-project", name=self.filename)

    self.sp = spm.SentencePieceProcessor()
    self.sp.Load(args.spm_path)
    self.binding = jericho.load_bindings(args.rom_path)
    self.vocab_act, self.vocab_act_rev = self.load_vocab_act(args.rom_path)
    # (The former unused local `vocab_size = len(self.sp)` was dropped.)
    self.vocab_size_act = len(self.vocab_act)
    self.template_generator = TemplateActionGenerator(self.binding)
    self.template_size = len(self.template_generator.templates)
    self.num_steps = args.steps
class JerichoEnv(FrotzEnv):
    '''
    Returns valid actions at each step of the game.

    Valid-action lists are cached in Redis (database 0 on localhost:6379),
    keyed by the emulator's world-state hash. The emulator that is actually
    stepped (``self.env``), the Redis connection and the reverse vocabulary
    are only built by ``create()``.
    '''

    # Returned whenever no valid action could be identified.
    _DEFAULT_VALID = ('wait', 'yes', 'no')

    def __init__(self, rom_path, seed, step_limit=None):
        super(JerichoEnv, self).__init__(rom_path)
        self.rom_path = rom_path
        self.bindings = load_bindings(rom_path)
        self.act_gen = TemplateActionGenerator(self.bindings)
        self.seed = seed
        self.steps = 0
        self.step_limit = step_limit
        # Assigned by create().
        self.env = None
        self.conn = None
        self.vocab_rev = None

    def create(self):
        """Build the stepping emulator, reverse vocabulary and Redis cache.

        NOTE: flushdb() deletes every key in Redis db 0 on localhost:6379.
        """
        self.env = FrotzEnv(self.rom_path, self.seed)
        self.vocab_rev = load_vocab_rev(self.env)
        self.conn = redis.Redis(host='localhost', port=6379, db=0)
        self.conn.flushdb()

    def get_valid(self, ob):
        """Return the valid action strings for the emulator's current state.

        On a cache miss, candidate template actions are built from the
        interactive objects in ``ob``, checked with ``find_valid_actions``, and
        their reprs are stored in Redis joined by '<SEP>'. Falls back to
        ['wait', 'yes', 'no'] when no valid action is found.
        """
        world_state_hash = self.env.get_world_state_hash()
        cached = self.conn.get(world_state_hash)
        if cached is None:
            objs = [o[0] for o in self.env.identify_interactive_objects(ob)]
            max_len = self.bindings['max_word_length']
            obj_ids = [self.vocab_rev[o[:max_len]] for o in objs]
            acts = self.act_gen.generate_template_actions(objs, obj_ids)
            found = self.env.find_valid_actions(acts)
            self.conn.set(world_state_hash,
                          '<SEP>'.join(str(a) for a in found))
            valid = [a.action for a in found]
        else:
            decoded = cached.decode('cp1252')
            # NOTE(review): eval() rebuilds TemplateAction objects from their
            # repr. This is only safe while nothing but this class writes to
            # the Redis DB; a shared or untrusted Redis allows code injection.
            valid = ([eval(a).action for a in decoded.split('<SEP>')]
                     if decoded else [])
        if not valid:
            valid = list(self._DEFAULT_VALID)
        return valid

    def step(self, action, valid_out=True):
        """Take ``action``; also record look/inventory text and valid actions.

        The look and inventory probes run on a saved emulator state that is
        restored afterwards, so they do not advance the game. ``info['valid']``
        is computed via get_valid() only when ``valid_out`` is true.
        """
        ob, reward, done, info = self.env.step(action)
        # Initialize with default values
        info['look'] = 'unknown'
        info['inv'] = 'unknown'
        info['valid'] = list(self._DEFAULT_VALID)
        if not done:
            try:
                save = self.env.save_str()
                look, _, _, _ = self.env.step('look')
                info['look'] = look
                self.env.load_str(save)
                inv, _, _, _ = self.env.step('inventory')
                info['inv'] = inv
                self.env.load_str(save)
                if valid_out:
                    # Same computation that was previously duplicated inline.
                    info['valid'] = self.get_valid(ob)
            except RuntimeError:
                print('RuntimeError: {}, Done: {}, Info: {}'.format(
                    clean(ob), done, info))
        self.steps += 1
        if self.step_limit and self.steps >= self.step_limit:
            done = True
        return ob, reward, done, info

    def reset(self):
        """Reset the game and return (initial observation, info).

        ``info`` gains 'look', 'inv' and 'valid'. Here the valid actions come
        from untemplated candidates and are not cached in Redis.
        """
        initial_ob, info = self.env.reset()
        save = self.env.save_str()
        look, _, _, _ = self.env.step('look')
        info['look'] = look
        self.env.load_str(save)
        inv, _, _, _ = self.env.step('inventory')
        info['inv'] = inv
        self.env.load_str(save)
        objs = [
            o[0] for o in self.env.identify_interactive_objects(initial_ob)
        ]
        acts = self.act_gen.generate_actions(objs)
        info['valid'] = self.env.find_valid_actions(acts)
        self.steps = 0
        return initial_ob, info

    def copy(self):
        """Return a new wrapper around a copy of the current emulator.

        The Redis connection is shared. The step limit and current step count
        are carried over; previously the copy silently lost its step limit.
        """
        copy_env = JerichoEnv(self.rom_path, self.seed,
                              step_limit=self.step_limit)
        copy_env.steps = self.steps
        copy_env.env = self.env.copy()
        copy_env.conn = self.conn
        copy_env.vocab_rev = load_vocab_rev(self.env)
        return copy_env

    def get_state(self):
        """Return (emulator state, Redis connection, step limit, vocab copy)."""
        state = self.env.get_state()
        conn = self.conn
        step_limit = self.step_limit
        vocab_rev = self.vocab_rev.copy()
        return state, conn, step_limit, vocab_rev

    def set_state(self, states):
        """Restore a tuple produced by get_state()."""
        state, conn, step_limit, vocab_rev = states
        self.env.set_state(state)
        self.conn = conn
        self.step_limit = step_limit
        self.vocab_rev = vocab_rev

    def get_dictionary(self):
        """Return the game dictionary, calling create() first if needed."""
        if not self.env:
            self.create()
        return self.env.get_dictionary()

    def get_action_set(self):
        return None

    def close(self):
        """Close the stepping emulator, if create() has built one."""
        if self.env is not None:
            self.env.close()
] #ex. ['mailbox', 'boarded', 'white'] candidate_actions = act_gen.generate_actions(interactive_objs) #ex. ['drive boarded', 'swim in mailbox', 'jump white', 'kick boarded','pour white in boarded', ... ] valid_actions = env.find_valid_actions(candidate_actions) chosen_action = random.choice(valid_actions) if args.verbose: print(colored(("valid actions:", valid_actions), 'green')) print(colored(("chosen action:", chosen_action), 'red')) return chosen_action args = parse_args() env = FrotzEnv(args.rom_path, seed=12) bindings = load_bindings(args.rom_path) act_gen = TemplateActionGenerator(bindings) obs, info = env.reset() done = False while not done: if args.debug: import ipdb ipdb.set_trace() args.max_iterations -= 1 if (args.max_iterations < 1): print(colored('Max Iterations Exceeded -- BREAK', 'red')) break # Use TemplateActionGenerator to select an action chosen_action = get_random_action(env, act_gen, args)
class JerichoEnv:
    """Jericho FrotzEnv wrapper that also computes grouped valid actions.

    Candidate template actions are checked either serially on ``self.env`` or
    on ``env_num`` worker processes, each owning its own FrotzEnv. Results
    are memoised per world-state hash in ``self.state2valid_acts``.
    """

    def __init__(self, rom_path, seed, step_limit=None, env_num=8):
        self.rom_path = rom_path
        self.step_limit = step_limit
        self.bindings = load_bindings(rom_path)
        # NOTE(review): the caller's `seed` is discarded; the seed stored in
        # the game bindings is always used.
        self.seed = self.bindings['seed']
        # Extra templates, possibly game specific. Ids follow the sorted order
        # of the combined list, so editing this shifts every template id.
        self.additional_templates = ['land']
        self.act_gen = TemplateActionGenerator(self.bindings)
        self.act_gen.templates = sorted(
            set(self.act_gen.templates + self.additional_templates))
        self.id2template = None
        self.template2id = None
        self._compute_template()

        self.env = FrotzEnv(self.rom_path, self.seed)

        self.env_num = env_num
        self.ps = None
        self.envs = None
        self.remotes = None
        self.work_remotes = None
        self.parallel = False
        self._init_parallel_workers()

        self.steps = 0
        self.max_word_len = self.bindings['max_word_length']
        self.word2id = None
        self.id2word = None
        self.noun_words = None
        self._compute_vocab_act()
        # world-state hash -> (valid_acts, candidate_acts, objs, invalid_acts)
        self.state2valid_acts = {}

    def _init_parallel_workers(self):
        """Start env_num daemon workers; a no-op (parallel stays False) if <= 0."""
        if self.env_num <= 0:
            return
        self.parallel = True
        self.envs = [
            FrotzEnv(self.rom_path, self.seed) for _ in range(self.env_num)
        ]
        self.remotes, self.work_remotes = zip(
            *[Pipe() for _ in range(self.env_num)])
        self.ps = [
            Process(target=worker, args=(work_remote, remote, env))
            for work_remote, remote, env in zip(self.work_remotes,
                                                self.remotes, self.envs)
        ]
        for p in self.ps:
            p.daemon = True
            p.start()
        # The parent only talks through self.remotes; close its copies of the
        # worker-side ends.
        for remote in self.work_remotes:
            remote.close()

    def _compute_vocab_act(self):
        """Load the action vocabulary directly from the Jericho dictionary.

        Ids 0 and 1 are reserved for ' ' and '<s>'; dictionary words start at 2.
        """
        env_dict = self.env.get_dictionary()
        vocab = {i + 2: str(v) for i, v in enumerate(env_dict)}
        vocab[0] = ' '
        vocab[1] = '<s>'
        self.word2id = {v: idx for idx, v in vocab.items()}
        self.id2word = vocab
        self.noun_words = {w.word for w in env_dict if w.is_noun}

    def _compute_template(self):
        """Assign template ids in the order of self.act_gen.templates."""
        self.id2template = {}
        self.template2id = {}
        for i, t in enumerate(self.act_gen.templates):
            self.id2template[i] = t
            self.template2id[t] = i

    def get_max_word_len(self):
        return self.max_word_len

    def get_dictionary(self):
        return self.env.get_dictionary()

    def get_bindings(self):
        return self.bindings

    def get_id2act_word(self):
        return self.id2word

    def get_id2template(self):
        return self.id2template

    def get_template2id(self):
        return self.template2id

    def get_act_word2id(self):
        return self.word2id

    def close(self):
        """Close the main emulator and, if workers were started, their envs/pipes."""
        self.env.close()
        # envs/remotes stay None when env_num <= 0; iterating None would raise.
        for env in self.envs or ():
            env.close()
        for remote in self.remotes or ():
            remote.close()

    def tmpl_to_str(self, template_idx, o1_id, o2_id):
        """Fill the template's 'OBJ' holes (at most two) with vocabulary words."""
        template_str = self.act_gen.templates[template_idx]
        holes = template_str.count('OBJ')
        assert holes <= 2
        if holes <= 0:
            return template_str
        elif holes == 1:
            return template_str.replace('OBJ', self.id2word[o1_id])
        else:
            return (template_str.replace('OBJ', self.id2word[o1_id], 1)
                    .replace('OBJ', self.id2word[o2_id], 1))

    def get_world_state_hash(self, ignore=False):
        """Emulator's own hash if ``ignore``; else the module-level helper's hash."""
        if ignore:
            return self.env.get_world_state_hash()
        return _get_world_state_hash(self.env)

    def _identify_objects_on_current_state(self, ob):
        """Return (objs, obj_ids) for interactive nouns, one object per word id."""
        objs_raw = self.env.identify_interactive_objects(ob)
        objs_raw = sorted(itertools.chain.from_iterable(objs_raw))
        objs = []
        obj_ids = []
        for obj in objs_raw:
            word = obj[:self.max_word_len]
            if word in self.noun_words:
                obj_id = self.word2id[word]
                if obj_id not in obj_ids:
                    objs.append(obj)
                    obj_ids.append(obj_id)
        return objs, obj_ids

    def _generate_all_template_actions(self, objs, obj_ids):
        return self.act_gen.generate_template_actions(objs, obj_ids)

    def _check_valid_action_parallel(self, candidate_actions):
        """Check candidates on the worker processes and merge their groups.

        Candidates are split round-robin across workers. Each worker replies
        with a mapping of group key -> actions (presumably keyed by world diff,
        mirroring _check_valid_action_serial; `worker` is not shown here). Each
        group is reduced to its lowest (template_id, obj_ids) action.
        """
        chunks = [[
            act for act_id, act in enumerate(candidate_actions)
            if act_id % self.env_num == i
        ] for i in range(self.env_num)]
        state = self.env.get_state()
        for remote, chunk in zip(self.remotes, chunks):
            remote.send((state, chunk))
        results = [remote.recv() for remote in self.remotes]
        keys = sorted(
            set(itertools.chain.from_iterable(out.keys() for out in results)))
        # .get() so a worker whose reply lacks a key does not raise KeyError.
        valid_actions = [
            list(itertools.chain.from_iterable(
                out.get(key, []) for out in results))
            for key in keys
        ]
        for v in valid_actions:
            v.sort(key=lambda x: (x.template_id, tuple(x.obj_ids)))
        return [[v[0]] for v in valid_actions]

    # the serial version copied from jericho, as parallel target
    def _check_valid_action_serial(self, candidate_actions):
        """Group candidates by the world diff they cause; keep one per group.

        The emulator state is restored before each candidate and at the end.
        """
        diff2acts = defaultdict(list)
        orig_score = self.env.get_score()
        state = self.env.get_state()
        for act in candidate_actions:
            self.env.set_state(state)
            if isinstance(act, defines.TemplateAction):
                obs, rew, done, info = self.env.step(act.action)
            else:
                obs, rew, done, info = self.env.step(act)
            if self.env.emulator_halted():
                self.env.reset()
                continue
            if info['score'] != orig_score or done or self.env.world_changed():
                # Heuristic to ignore actions with side-effect of taking items
                if '(Taken)' in obs:
                    continue
                diff = self.env._get_world_diff()
                diff2acts[diff].append(act)
        valid_acts = [diff2acts[key] for key in sorted(diff2acts)]
        for v in valid_acts:
            v.sort(key=lambda x: (x.template_id, tuple(x.obj_ids)))
        valid_acts = [[v[0]] for v in valid_acts]
        self.env.set_state(state)
        return valid_acts

    def _get_action_on_current_state(self, state_hash, ob, parallel,
                                     compute_actions):
        """Return (valid_acts, acts, objs, invalid_acts) for ``state_hash``.

        Served from the cache when possible. Otherwise computed (and cached)
        only if ``compute_actions``; else four empty lists are returned.
        """
        valid_ao = self.state2valid_acts.get(state_hash)
        if valid_ao is None and compute_actions:
            # Identifies objects in the current location and inventory
            # that are likely to be interactive.
            objs, obj_ids = self._identify_objects_on_current_state(ob)
            acts = self._generate_all_template_actions(objs, obj_ids)
            # Without worker processes (env_num <= 0) the parallel path would
            # silently report no valid actions, so fall back to serial.
            if parallel and self.parallel:
                valid_acts = self._check_valid_action_parallel(acts)
            else:
                valid_acts = self._check_valid_action_serial(acts)
            # Removed-action tracking is currently disabled.
            invalid_act = []
            valid_ao = (valid_acts, acts, objs, invalid_act)
            self.state2valid_acts[state_hash] = valid_ao
        if valid_ao is None:
            return [], [], [], []
        return valid_ao

    def step(self, action, confidence=0, parallel=True, compute_actions=True):
        """Take ``action`` and return (ob, reward, done, info).

        ``ob`` is 'look|inventory|observation|action', each part passed through
        clean_obs. ``info`` gains 'world_changed', 'valid_act', 'act', 'objs'
        and 'invalid_act'. If the world did not change, the action is replayed
        up to ``confidence`` times from a saved state (then restored) to reduce
        the effect of stochastic outcomes.
        """
        ob, reward, done, info = self.env.step(action)
        world_changed = self.env.world_changed()
        info['world_changed'] = world_changed
        # Initialize with default values
        look = 'unknown'
        inv = 'unknown'
        if not done:
            try:
                if not world_changed and confidence > 0:
                    save = self.env.get_state()
                    for _ in range(confidence):
                        self.env.step(action)
                        world_changed = self.env.world_changed()
                        if world_changed:
                            break
                    self.env.set_state(save)
                    info['world_changed'] = world_changed
                self.steps += 1
                state_hash = _get_world_state_hash(self.env)
                save = self.env.get_state()
                look, _, _, _ = self.env.step('look')
                self.env.set_state(save)
                inv, _, _, _ = self.env.step('inventory')
                self.env.set_state(save)
                # Find Valid actions
                act_info = self._get_action_on_current_state(
                    state_hash, ob, parallel, compute_actions)
                info['valid_act'] = act_info[0]
                info['act'] = act_info[1]
                info['objs'] = act_info[2]
                info['invalid_act'] = act_info[3]
            except RuntimeError:
                print('RuntimeError: {}, Done: {}, Info: {}'.format(
                    clean_obs(ob), done, info))
                info['valid_act'] = []
                info['act'] = []
                info['objs'] = []
                info['invalid_act'] = []
                done = True
        else:
            info['valid_act'] = []
            info['act'] = []
            info['objs'] = []
            info['invalid_act'] = []
        if self.step_limit and self.steps >= self.step_limit:
            done = True
        ob = (clean_obs(look) + '|' + clean_obs(inv) + '|' + clean_obs(ob) +
              '|' + clean_obs(action))
        return ob, reward, done, info

    def reset(self, parallel=True, compute_actions=True):
        """Reset the game; returns ('look|inventory|observation|look', info)."""
        initial_ob, info = self.env.reset()
        save = self.env.get_state()
        self.steps = 0
        look, inv = '', ''
        try:
            state_hash = _get_world_state_hash(self.env)
            look, _, _, _ = self.env.step('look')
            self.env.set_state(save)
            inv, _, _, _ = self.env.step('inventory')
            self.env.set_state(save)
            # compute valid state for initial obs
            act_info = self._get_action_on_current_state(
                state_hash, initial_ob, parallel, compute_actions)
            info['valid_act'] = act_info[0]
            info['act'] = act_info[1]
            info['objs'] = act_info[2]
            info['invalid_act'] = act_info[3]
            info['world_changed'] = True
        except RuntimeError:
            print('RuntimeError: {}, Info: {}'.format(initial_ob, info))
            info['valid_act'] = []
            info['act'] = []
            info['objs'] = []
            info['invalid_act'] = []
            info['world_changed'] = True
            # Make the next step() hit the step limit. With no limit,
            # assigning None here would break `self.steps += 1` in step().
            if self.step_limit:
                self.steps = self.step_limit
        initial_ob = clean_obs(look) + '|' + clean_obs(inv) + '|' + clean_obs(
            initial_ob) + '|' + clean_obs('look')
        return initial_ob, info

    def align_action_on_current_state(self, target_action, action_groups):
        """Return the index of the group whose first action reproduces the
        world diff of ``target_action``, or -1. The state is always restored.

        NOTE(review): ``orig_score`` is the score *after* target_action, unlike
        the pre-action score used in _check_valid_action_serial — confirm intent.
        """
        state = self.env.get_state()
        obs, rew, done, info = self.env.step(target_action)
        target_diff = self.env._get_world_diff()
        orig_score = info['score']
        self.env.set_state(state)
        for idx, group in enumerate(action_groups):
            # the first action as the archetype action
            act = group[0]
            template_id = act[1]
            obj1_id = act[2][0] if len(act[2]) > 0 else None
            obj2_id = act[2][1] if len(act[2]) > 1 else None
            act_str = self.tmpl_to_str(template_id, obj1_id, obj2_id)
            obs, rew, done, info = self.env.step(act_str)
            diff = None
            if info['score'] != orig_score or done or self.env.world_changed():
                diff = self.env._get_world_diff()
            if diff == target_diff:
                self.env.set_state(state)
                return idx
            self.env.set_state(state)
        return -1
class JerichoEnv:
    """Jericho wrapper whose observations concatenate look/inventory text.

    Valid actions for each world state are cached in Redis (database 0 on
    localhost:6379) as '/'-joined TemplateAction reprs. The emulator and the
    Redis connection are only built by ``create()``.
    """

    def __init__(self, rom_path, seed, vocab_rev, step_limit=None):
        self.rom_path = rom_path
        self.bindings = load_bindings(rom_path)
        self.act_gen = TemplateActionGenerator(self.bindings)
        self.seed = seed
        self.steps = 0
        self.step_limit = step_limit
        self.vocab_rev = vocab_rev
        # Assigned by create().
        self.env = None
        self.conn = None

    def create(self):
        """Build the emulator, start Redis and clear db 0 (all of its keys)."""
        self.env = FrotzEnv(self.rom_path, self.seed)
        start_redis()
        self.conn = redis.Redis(host='localhost', port=6379, db=0)
        self.conn.flushdb()

    def step(self, action):
        """Take ``action``; return (ob, reward, done, info).

        If the action neither ends the game nor changes the world,
        (None, None, None, info) is returned and the step is not counted.
        Otherwise ``ob`` is 'look|inventory|observation|action' and
        ``info['valid']`` holds TemplateActions (empty if unavailable).
        """
        ob, reward, done, info = self.env.step(action)
        action_valid = done or self.env.world_changed()
        info['action_valid'] = action_valid
        if not action_valid:
            # Exit early for invalid actions
            return None, None, None, info
        self.steps += 1

        # Initialize with default values
        look = 'unknown'
        inv = 'unknown'
        info['valid'] = []
        if not done:
            try:
                # Probe look/inventory from a saved state, then restore it.
                save = self.env.save_str()
                look, _, _, _ = self.env.step('look')
                self.env.load_str(save)
                inv, _, _, _ = self.env.step('inventory')
                self.env.load_str(save)
                # Find Valid actions (cached per world-state hash)
                world_state_hash = self.env.get_world_state_hash()
                valid = self.conn.get(world_state_hash)
                if valid is None:
                    objs = [
                        o[0] for o in self.env.identify_interactive_objects(ob)
                    ]
                    max_len = self.bindings['max_word_length']
                    obj_ids = [self.vocab_rev[o[:max_len]] for o in objs]
                    acts = self.act_gen.generate_template_actions(
                        objs, obj_ids)
                    valid = self.env.find_valid_actions(acts)
                    redis_valid_value = '/'.join([str(a) for a in valid])
                    self.conn.set(world_state_hash, redis_valid_value)
                else:
                    decoded = valid.decode('cp1252')
                    # NOTE(review): eval() of Redis contents is only safe while
                    # this process is the sole writer of the DB.
                    # An empty cached value means "no valid actions": return an
                    # empty list rather than leaking the '' string.
                    valid = ([eval(a) for a in decoded.split('/')]
                             if decoded else [])
                info['valid'] = valid
            except RuntimeError:
                print('RuntimeError: {}, Done: {}, Info: {}'.format(
                    clean(ob), done, info))
        if self.step_limit and self.steps >= self.step_limit:
            done = True
        ob = look + '|' + inv + '|' + ob + '|' + action
        return ob, reward, done, info

    def reset(self):
        """Reset the game; return ('look|inventory|observation|look', info)."""
        initial_ob, info = self.env.reset()
        try:
            save = self.env.save_str()
            look, _, _, _ = self.env.step('look')
            self.env.load_str(save)
            inv, _, _, _ = self.env.step('inventory')
            self.env.load_str(save)
        except RuntimeError:
            print('RuntimeError: {}, Info: {}'.format(initial_ob, info))
            # Was `look, inv = ''`, which itself raised ValueError.
            look, inv = '', ''
        self.steps = 0
        initial_ob = look + '|' + inv + '|' + initial_ob + '|' + 'look'
        return initial_ob, info

    def get_dictionary(self):
        """Return the game dictionary, calling create() first if needed."""
        if not self.env:
            self.create()
        return self.env.get_dictionary()

    def get_action_set(self):
        return None

    def close(self):
        """Close the emulator and shut down Redis, if create() built them."""
        if self.env is not None:
            self.env.close()
        if self.conn is not None:
            # Sends SHUTDOWN (with save) to the Redis server itself.
            self.conn.shutdown(save=True)