Ejemplo n.º 1
0
 def __init__(self, rom_path, seed, step_limit=None):
     """Record the ROM/seed configuration and build the action generator.

     Args:
         rom_path: path to the game ROM; its bindings are loaded here.
         seed: seed stored for later use.
         step_limit: optional maximum number of steps per episode.

     The emulator, connection and reverse vocabulary are left as None
     here and assigned elsewhere.
     """
     # Game-specific bindings and the template generator built from them.
     self.rom_path = rom_path
     self.bindings = load_bindings(rom_path)
     self.act_gen = TemplateActionGenerator(self.bindings)
     # Episode bookkeeping.
     self.seed, self.steps, self.step_limit = seed, 0, step_limit
     # Resources that are not created in the constructor.
     self.env = self.conn = self.vocab_rev = None
Ejemplo n.º 2
0
    def __init__(self, params):
        """Build the KG-A2C training setup from a ``params`` dict.

        Configures logging and loads the SentencePiece tokenizer. Builds the
        vectorised KGA2C environment, the template action generator and the
        KGA2C model; the model is replaced by ``params['preload_weights']``
        when that is set. Finally creates the Adam optimizer and the three
        loss functions.

        Args:
            params: configuration dict. Keys read here: 'output_dir',
                'rom_file_path', 'spm_file', 'seed', 'tsv_file',
                'reset_steps', 'stuck_steps', 'gat', 'batch_size',
                'openie_path', 'preload_weights' and 'lr'.
        """
        configure_logger(params['output_dir'])
        log('Parameters {}'.format(params))
        self.params = params
        self.binding = load_bindings(params['rom_file_path'])
        self.max_word_length = self.binding['max_word_length']
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(params['spm_file'])
        kg_env = KGA2CEnv(params['rom_file_path'],
                          params['seed'],
                          self.sp,
                          params['tsv_file'],
                          step_limit=params['reset_steps'],
                          stuck_steps=params['stuck_steps'],
                          gat=params['gat'])
        self.vec_env = VecEnv(params['batch_size'], kg_env,
                              params['openie_path'])
        self.template_generator = TemplateActionGenerator(self.binding)
        # This emulator is only needed to read the game dictionary; close it
        # afterwards so its emulator resources are not leaked.
        env = FrotzEnv(params['rom_file_path'])
        try:
            self.vocab_act, self.vocab_act_rev = load_vocab(env)
        finally:
            env.close()
        self.model = KGA2C(params,
                           self.template_generator.templates,
                           self.max_word_length,
                           self.vocab_act,
                           self.vocab_act_rev,
                           len(self.sp),
                           gat=self.params['gat']).cuda()
        self.batch_size = params['batch_size']
        if params['preload_weights']:
            # The checkpoint stores the whole model object under 'model'.
            self.model = torch.load(self.params['preload_weights'])['model']
        self.optimizer = optim.Adam(self.model.parameters(), lr=params['lr'])

        self.loss_fn1 = nn.BCELoss()
        self.loss_fn2 = nn.BCEWithLogitsLoss()
        self.loss_fn3 = nn.MSELoss()
Ejemplo n.º 3
0
def test_valid_action_identification():
    """Check that find_valid_actions accepts known actions in 905.z5.

    The interactive objects of the initial game state are hard-coded
    rather than detected from the object tree.
    """
    rom_path = pjoin(DATA_PATH, '905.z5')
    env = jericho.FrotzEnv(rom_path)
    try:
        bindings = jericho.load_bindings(rom_path)
        act_gen = TemplateActionGenerator(bindings)
        obs, info = env.reset()
        interactive_objs = ['phone', 'keys', 'wallet']
        candidate_actions = act_gen.generate_actions(interactive_objs)

        valid = env.find_valid_actions(candidate_actions)
        assert 'take wallet' in valid
        assert 'open wallet' in valid
        assert 'take keys' in valid
        assert 'get up' in valid
        assert 'take phone' in valid
    finally:
        # Release the emulator even when an assertion fails.
        env.close()
Ejemplo n.º 4
0
 def create(self):
     """Create the seeded emulator and its derived helpers, then open Redis.

     Builds a FrotzEnv from ``rom_path``/``seed``. From that emulator it
     takes the bindings, template generator and reverse vocabulary. It
     then connects to Redis on localhost:6379 (db 0) and flushes that
     database.
     """
     env = FrotzEnv(self.rom_path, self.seed)
     self.env = env
     self.bindings = env.bindings
     self.act_gen = TemplateActionGenerator(self.bindings)
     self.vocab_rev = load_vocab_rev(env)
     connection = redis.Redis(host='localhost', port=6379, db=0)
     self.conn = connection
     connection.flushdb()
Ejemplo n.º 5
0
    def __init__(self, params):
        """Build the Q*BERT training setup from a ``params`` dict.

        Steps, in order:
          * seed torch, numpy and ``random``, and configure logging;
          * make sure the checkpoint directory exists;
          * load the SentencePiece tokenizer and build the vectorised QBERT
            environment;
          * read the max score, initial state and action vocabulary from a
            temporary emulator;
          * build the QBERT model (replaced by ``preload_weights`` when
            set), the Adam optimizer and the loss functions.

        Args:
            params: configuration dict; see the key lookups below.
        """
        torch.manual_seed(params['seed'])
        np.random.seed(params['seed'])
        random.seed(params['seed'])
        configure_logger(params['output_dir'])
        log('Parameters {}'.format(params))
        self.params = params
        self.chkpt_path = os.path.dirname(self.params['checkpoint_path'])
        # os.mkdir failed when a parent directory was missing, and on '' when
        # the checkpoint path has no directory part (current directory).
        if self.chkpt_path:
            os.makedirs(self.chkpt_path, exist_ok=True)
        self.binding = load_bindings(params['rom_file_path'])
        self.max_word_length = self.binding['max_word_length']
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(params['spm_file'])
        self.askbert = params['extraction']
        kg_env = QBERTEnv(params['rom_file_path'],
                          params['seed'],
                          self.sp,
                          params['tsv_file'],
                          params['attr_file'],
                          step_limit=params['reset_steps'],
                          stuck_steps=params['stuck_steps'],
                          gat=params['gat'],
                          askbert=self.askbert,
                          clear_kg=params['clear_kg_on_reset'])

        self.vec_env = VecEnv(params['batch_size'], kg_env,
                              params['openie_path'], params['redis_db_path'],
                              params['buffer_size'], params['extraction'],
                              params['training_type'],
                              params['clear_kg_on_reset'])
        self.template_generator = TemplateActionGenerator(self.binding)
        # Temporary emulator used only to read game metadata; close it once
        # done so its resources are not leaked.
        env = FrotzEnv(params['rom_file_path'])
        try:
            self.max_game_score = env.get_max_score()
            self.cur_reload_state = env.get_state()
            self.vocab_act, self.vocab_act_rev = load_vocab(env)
        finally:
            env.close()
        self.model = QBERT(params,
                           self.template_generator.templates,
                           self.max_word_length,
                           self.vocab_act,
                           self.vocab_act_rev,
                           len(self.sp),
                           gat=self.params['gat']).cuda()
        self.batch_size = params['batch_size']
        if params['preload_weights']:
            # The checkpoint stores the whole model object under 'model'.
            self.model = torch.load(self.params['preload_weights'])['model']
        self.optimizer = optim.Adam(self.model.parameters(), lr=params['lr'])

        self.loss_fn1 = nn.BCELoss()
        self.loss_fn2 = nn.BCEWithLogitsLoss()
        self.loss_fn3 = nn.MSELoss()

        self.chained_logger = params['chained_logger']
        self.total_steps = 0
Ejemplo n.º 6
0
    def __init__(self, params, args):
        """Build the KG-A2C training setup, optionally with COMET extraction.

        Args:
            params: configuration dict. Keys read here: 'output_dir',
                'rom_file_path', 'spm_file', 'use_cs', 'seed', 'tsv_file',
                'reset_steps', 'stuck_steps', 'gat', 'batch_size',
                'openie_path', 'device_a2c', 'preload_weights' and 'lr'.
            args: extra arguments forwarded to ``CometHelper`` and stored
                as ``self.args``.
        """
        configure_logger(params['output_dir'])
        log('Parameters {}'.format(params))
        self.params = params
        self.binding = load_bindings(params['rom_file_path'])
        self.max_word_length = self.binding['max_word_length']
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(params['spm_file'])
        self.use_cs = self.params['use_cs']
        # Always define the attribute so that callers can test it against
        # None instead of hitting AttributeError when COMET is disabled.
        self.kg_extract = None
        # Explicit comparison kept so non-bool config values behave as before.
        if self.use_cs == True:
            print("Using COMET")
            self.kg_extract = CometHelper(args)
        kg_env = KGA2CEnv(params['rom_file_path'],
                          params['seed'],
                          self.sp,
                          params['tsv_file'],
                          step_limit=params['reset_steps'],
                          stuck_steps=params['stuck_steps'],
                          gat=params['gat'])

        self.vec_env = VecEnv(params['batch_size'], kg_env,
                              params['openie_path'])
        self.template_generator = TemplateActionGenerator(self.binding)
        # Temporary emulator used only to read the game dictionary.
        env = FrotzEnv(params['rom_file_path'])
        try:
            self.vocab_act, self.vocab_act_rev = load_vocab(env)
        finally:
            env.close()
        # Select the A2C GPU before building the model on it.
        torch.cuda.set_device(int(self.params['device_a2c']))
        self.model = KGA2C(params,
                           self.template_generator.templates,
                           self.max_word_length,
                           self.vocab_act,
                           self.vocab_act_rev,
                           len(self.sp),
                           a2c_device=(int(self.params['device_a2c'])),
                           gat=self.params['gat'])

        self.batch_size = params['batch_size']
        if params['preload_weights']:
            # The checkpoint stores the whole model object under 'model'.
            self.model = torch.load(self.params['preload_weights'])['model']
        self.optimizer = optim.Adam(self.model.parameters(), lr=params['lr'])

        self.loss_fn1 = nn.BCELoss()
        self.loss_fn2 = nn.BCEWithLogitsLoss()
        self.loss_fn3 = nn.MSELoss()

        self.args = args
Ejemplo n.º 7
0
    def __init__(self, rom_path, seed, step_limit=None, env_num=8):
        """Set up templates, emulator(s) and the action vocabulary.

        Args:
            rom_path: path to the game ROM.
            seed: stored first, then overwritten by the seed found in the
                game's bindings.
            step_limit: optional maximum number of steps per episode.
            env_num: number of worker emulators, presumably consumed by
                ``_init_parallel_workers`` (TODO confirm).
        """
        self.rom_path = rom_path
        self.seed = seed
        self.step_limit = step_limit

        self.bindings = load_bindings(rom_path)
        # NOTE(review): the caller's seed is replaced here by the bindings'
        # seed, so the ``seed`` argument has no lasting effect.
        self.seed = self.bindings['seed']

        # Extra templates merged into the game's own list. The sorted,
        # de-duplicated list defines the template ids, so editing this list
        # shifts those ids.
        self.additional_templates = ['land']
        self.act_gen = TemplateActionGenerator(self.bindings)
        self.act_gen.templates = sorted(
            set(self.act_gen.templates + self.additional_templates))
        self.id2template = self.template2id = None
        self._compute_template()
        self.env = FrotzEnv(self.rom_path, self.seed)

        self.env_num = env_num
        self.ps = self.envs = self.remotes = self.work_remotes = None
        self.parallel = False
        self._init_parallel_workers()

        self.steps = 0
        self.max_word_len = self.bindings['max_word_length']

        self.word2id = self.id2word = self.noun_words = None
        self._compute_vocab_act()
        # Per-state-hash cache of computed valid actions.
        self.state2valid_acts = {}
Ejemplo n.º 8
0
    def __init__(self, args):
        """Set up a TDQN agent (online and target networks) from ``args``.

        Starts a wandb run, loads the SentencePiece model, the game bindings
        and the action vocabulary, and creates the replay buffer, the two
        TDQN models and the Adam optimizer.

        Raises:
            ValueError: if ``args.replay_buffer_type`` is neither
                'priority' nor 'standard'.
        """
        configure_logger(args.output_dir)
        log(args)
        self.args = args

        self.log_freq = args.log_freq
        self.update_freq = args.update_freq_td
        self.update_freq_tar = args.update_freq_tar

        self.filename = 'ptdqn' + args.rom_path + str(args.run_number)
        wandb.init(project="my-project", name=self.filename)

        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(args.spm_path)
        self.binding = jericho.load_bindings(args.rom_path)
        self.vocab_act, self.vocab_act_rev = self.load_vocab_act(args.rom_path)
        vocab_size = len(self.sp)
        vocab_size_act = len(self.vocab_act)

        self.template_generator = TemplateActionGenerator(self.binding)
        self.template_size = len(self.template_generator.templates)

        if args.replay_buffer_type == 'priority':
            self.replay_buffer = PriorityReplayBuffer(
                int(args.replay_buffer_size))
        elif args.replay_buffer_type == 'standard':
            self.replay_buffer = ReplayBuffer(int(args.replay_buffer_size))
        else:
            # Previously this fell through silently and left
            # ``self.replay_buffer`` undefined until its first use.
            raise ValueError(
                "Unknown replay_buffer_type {!r}; expected 'priority' or "
                "'standard'".format(args.replay_buffer_type))
        self.action_size = self.template_size
        self.action_parameter_size = vocab_size_act
        self.model = TDQN(args, self.action_size, self.action_parameter_size,
                          self.template_size, vocab_size,
                          vocab_size_act).cuda()
        self.target_model = TDQN(args, self.action_size,
                                 self.action_parameter_size,
                                 self.template_size, vocab_size,
                                 vocab_size_act).cuda()

        self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)

        self.num_steps = args.steps
        self.batch_size = args.batch_size
        self.gamma = args.gamma

        self.rho = args.rho
        self.vocab_size_act = vocab_size_act

        self.bce_loss = nn.BCELoss()
Ejemplo n.º 9
0
 def __init__(self, params):
     """Build the KG-A2C training setup, printing a banner for each stage.

     Args:
         params: configuration dict. Keys read here: 'output_dir',
             'rom_file_path', 'spm_file', 'seed', 'tsv_file',
             'reset_steps', 'stuck_steps', 'gat', 'batch_size',
             'openie_path', 'preload_weights' and 'lr'.
     """
     print("----- Initiating ----- ")
     print("----- step 1 configure logger")
     configure_logger(params['output_dir'])
     log('Parameters {}'.format(params))
     self.params = params
     print("----- step 2 load pre-collected things")
     self.binding = load_bindings(params['rom_file_path'])
     self.max_word_length = self.binding['max_word_length']
     self.sp = spm.SentencePieceProcessor()
     self.sp.Load(params['spm_file'])
     print("----- step 3 build KGA2CEnv")
     kg_env = KGA2CEnv(params['rom_file_path'],
                       params['seed'],
                       self.sp,
                       params['tsv_file'],
                       step_limit=params['reset_steps'],
                       stuck_steps=params['stuck_steps'],
                       gat=params['gat'])
     self.vec_env = VecEnv(params['batch_size'], kg_env,
                           params['openie_path'])
     print("----- step 4 build FrotzEnv and template generator")
     # Temporary emulator used only to read the game dictionary; close it
     # so its resources are not leaked.
     env = FrotzEnv(params['rom_file_path'])
     try:
         self.vocab_act, self.vocab_act_rev = load_vocab(env)
     finally:
         env.close()
     self.template_generator = TemplateActionGenerator(self.binding)
     print("----- step 5 build kga2c model")
     self.model = KGA2C(params,
                        self.template_generator.templates,
                        self.max_word_length,
                        self.vocab_act,
                        self.vocab_act_rev,
                        len(self.sp),
                        gat=self.params['gat']).cuda()
     if params['preload_weights']:
         print("load pretrained")
         # The checkpoint stores the whole model object under 'model'.
         self.model = torch.load(self.params['preload_weights'])['model']
     else:
         print("train from scratch")
     print("----- step 6 set training parameters")
     self.batch_size = params['batch_size']
     self.optimizer = optim.Adam(self.model.parameters(), lr=params['lr'])
     self.loss_fn1 = nn.BCELoss()
     self.loss_fn2 = nn.BCEWithLogitsLoss()
     self.loss_fn3 = nn.MSELoss()
     print("----- Init finished! ----- ")
Ejemplo n.º 10
0
    def __init__(self, args):
        """Set up an agent from parsed ``args``.

        This is presumably a random-policy baseline, given the 'random'
        run-name prefix. Starts a wandb run, loads the SentencePiece model,
        the game bindings and the action vocabulary, and records the
        template and action-vocabulary sizes.
        """
        self.args = args

        self.log_freq = args.log_freq
        self.update_freq = args.update_freq_td
        self.update_freq_tar = args.update_freq_tar
        self.filename = 'random' + args.rom_path + str(args.run_number)
        wandb.init(project="my-project", name=self.filename)
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(args.spm_path)
        self.binding = jericho.load_bindings(args.rom_path)
        self.vocab_act, self.vocab_act_rev = self.load_vocab_act(args.rom_path)
        # (An unused ``vocab_size = len(self.sp)`` local was removed.)
        self.vocab_size_act = len(self.vocab_act)

        self.template_generator = TemplateActionGenerator(self.binding)
        self.template_size = len(self.template_generator.templates)

        self.num_steps = args.steps
Ejemplo n.º 11
0
class JerichoEnv(FrotzEnv):
    ''' Returns valid actions at each step of the game.

    Although this class subclasses ``FrotzEnv``, the game is played through
    a separate emulator stored in ``self.env`` (built by ``create``). Valid
    actions are cached in a local Redis database, keyed by world-state hash.
    '''

    def __init__(self, rom_path, seed, step_limit=None):
        super(JerichoEnv, self).__init__(rom_path)
        self.rom_path = rom_path
        self.bindings = load_bindings(rom_path)
        self.act_gen = TemplateActionGenerator(self.bindings)

        self.seed = seed
        self.steps = 0
        self.step_limit = step_limit
        # The following are populated by create().
        self.env = None
        self.conn = None
        self.vocab_rev = None

    def create(self):
        """Create the seeded emulator, its reverse vocab and a flushed Redis db."""
        self.env = FrotzEnv(self.rom_path, self.seed)
        self.vocab_rev = load_vocab_rev(self.env)
        self.conn = redis.Redis(host='localhost', port=6379, db=0)
        self.conn.flushdb()

    def get_valid(self, ob):
        """Return the valid action strings for the current world state.

        First looks up the world-state hash in Redis. On a miss it:
          * generates template actions for the interactive objects in
            ``ob``;
          * filters them with ``find_valid_actions``;
          * stores the result as '<SEP>'-joined ``str()`` forms.
        Returns ['wait', 'yes', 'no'] when no action is valid.
        """
        world_state_hash = self.env.get_world_state_hash()
        valid = self.conn.get(world_state_hash)
        if valid is None:
            max_len = self.bindings['max_word_length']
            objs = [o[0] for o in self.env.identify_interactive_objects(ob)]
            obj_ids = [self.vocab_rev[o[:max_len]] for o in objs]
            acts = self.act_gen.generate_template_actions(objs, obj_ids)
            valid = self.env.find_valid_actions(acts)
            self.conn.set(world_state_hash,
                          '<SEP>'.join(str(a) for a in valid))
            valid = [a.action for a in valid]
        else:
            valid = valid.decode('cp1252')
            if valid:
                # SECURITY NOTE(review): eval() runs whatever is stored in
                # Redis. This class writes str(TemplateAction) values there,
                # but any other writer to this db could inject code; consider
                # a structured encoding such as JSON.
                valid = [eval(a).action for a in valid.split('<SEP>')]
            else:
                valid = []
        if not valid:
            valid = ['wait', 'yes', 'no']
        return valid

    def step(self, action, valid_out=True):
        """Advance the game by ``action`` and annotate ``info``.

        Fields added to ``info``:
          * 'look' and 'inv': the outputs of 'look' and 'inventory'. Each
            is captured and then the saved state is restored, so the game
            does not advance.
          * 'valid' (only when ``valid_out`` is true): see ``get_valid``.
        A field keeps its default ('unknown', or ['wait', 'yes', 'no'])
        when the game is done, or when a RuntimeError interrupts before
        the field is filled. ``done`` is also forced once ``step_limit``
        steps have been taken.
        """
        ob, reward, done, info = self.env.step(action)
        # Initialize with default values
        info['look'] = 'unknown'
        info['inv'] = 'unknown'
        info['valid'] = ['wait', 'yes', 'no']
        if not done:
            try:
                save = self.env.save_str()
                look, _, _, _ = self.env.step('look')
                info['look'] = look
                self.env.load_str(save)
                inv, _, _, _ = self.env.step('inventory')
                info['inv'] = inv
                self.env.load_str(save)

                if valid_out:
                    info['valid'] = self.get_valid(ob)
            except RuntimeError:
                print('RuntimeError: {}, Done: {}, Info: {}'.format(
                    clean(ob), done, info))
        self.steps += 1
        if self.step_limit and self.steps >= self.step_limit:
            done = True
        return ob, reward, done, info

    def reset(self):
        """Reset the game and return ``(initial_ob, info)``.

        ``info`` gains 'look' and 'inv', captured with save/restore so the
        game does not advance. It also gains 'valid': the result of
        ``find_valid_actions`` over actions generated for the interactive
        objects in the initial observation. This path neither reads nor
        fills the Redis cache.
        """
        initial_ob, info = self.env.reset()
        save = self.env.save_str()
        look, _, _, _ = self.env.step('look')
        info['look'] = look
        self.env.load_str(save)
        inv, _, _, _ = self.env.step('inventory')
        info['inv'] = inv
        self.env.load_str(save)
        objs = [
            o[0] for o in self.env.identify_interactive_objects(initial_ob)
        ]
        acts = self.act_gen.generate_actions(objs)
        info['valid'] = self.env.find_valid_actions(acts)
        self.steps = 0
        return initial_ob, info

    def copy(self):
        """Return a new wrapper with a copied emulator and shared Redis conn.

        The copy keeps this wrapper's ``step_limit`` and current ``steps``
        count; both were previously dropped, so the copy had no step limit.
        """
        copy_env = JerichoEnv(self.rom_path, self.seed, self.step_limit)
        copy_env.env = self.env.copy()
        copy_env.steps = self.steps
        copy_env.conn = self.conn
        copy_env.vocab_rev = load_vocab_rev(self.env)
        return copy_env

    def get_state(self):
        """Return (emulator state, Redis conn, step_limit, vocab_rev copy)."""
        state = self.env.get_state()
        conn = self.conn
        step_limit = self.step_limit
        vocab_rev = self.vocab_rev.copy()
        return state, conn, step_limit, vocab_rev

    def set_state(self, states):
        """Restore a tuple previously produced by ``get_state``."""
        state, conn, step_limit, vocab_rev = states
        self.env.set_state(state)
        self.conn = conn
        self.step_limit = step_limit
        self.vocab_rev = vocab_rev

    def get_dictionary(self):
        """Return the game dictionary, creating the emulator if needed."""
        if not self.env:
            self.create()
        return self.env.get_dictionary()

    def get_action_set(self):
        return None

    def close(self):
        """Close the wrapped emulator; a no-op if create() was never called."""
        if self.env is not None:
            self.env.close()
Ejemplo n.º 12
0
    ]
    #ex. ['mailbox', 'boarded', 'white']
    candidate_actions = act_gen.generate_actions(interactive_objs)
    #ex. ['drive boarded', 'swim in mailbox', 'jump white', 'kick boarded','pour white in boarded', ... ]
    valid_actions = env.find_valid_actions(candidate_actions)
    chosen_action = random.choice(valid_actions)
    if args.verbose:
        print(colored(("valid actions:", valid_actions), 'green'))
        print(colored(("chosen action:", chosen_action), 'red'))
    return chosen_action


# Script entry point: build the emulator (fixed seed 12) and a template
# action generator, then repeatedly pick an action with get_random_action
# until `done` or the iteration budget runs out.
args = parse_args()
env = FrotzEnv(args.rom_path, seed=12)
bindings = load_bindings(args.rom_path)
act_gen = TemplateActionGenerator(bindings)
obs, info = env.reset()
done = False

# NOTE(review): in the lines shown, nothing in this loop calls env.step()
# or updates `done`, so only the max_iterations countdown ends it. Confirm
# whether the loop body continues beyond this excerpt.
while not done:
    if args.debug:
        # Break into the ipdb debugger on every iteration when args.debug
        # is truthy.
        import ipdb
        ipdb.set_trace()
    # Decrement the remaining budget; stop once it drops below 1.
    args.max_iterations -= 1
    if (args.max_iterations < 1):
        print(colored('Max Iterations Exceeded -- BREAK', 'red'))
        break

    # Use TemplateActionGenerator to select an action
    chosen_action = get_random_action(env, act_gen, args)
Ejemplo n.º 13
0
class JerichoEnv:
    """Jericho game wrapper that exposes template-based valid actions.

    Valid actions are found in three stages:
      * execute every generated template action from the current state,
        either serially or split across ``env_num`` worker processes;
      * group the actions by the world diff they produce;
      * keep one representative action per group.
    Results are memoised per world-state hash in ``state2valid_acts``.
    """

    # Keys written into ``info`` by step()/reset(), in the order of the
    # 4-tuple returned by _get_action_on_current_state().
    _ACT_INFO_KEYS = ('valid_act', 'act', 'objs', 'invalid_act')

    def __init__(self, rom_path, seed, step_limit=None, env_num=8):
        self.rom_path = rom_path
        self.seed = seed

        self.step_limit = step_limit

        self.bindings = load_bindings(rom_path)
        # NOTE(review): this overrides the ``seed`` argument with the
        # bindings' seed, so the argument has no lasting effect.
        self.seed = self.bindings['seed']
        # Additional templates (could be made game specific). Changing this
        # list shifts the template ids, since ids come from the sorted list.
        self.additional_templates = ['land']
        self.act_gen = TemplateActionGenerator(self.bindings)
        self.act_gen.templates = sorted(
            set(self.act_gen.templates + self.additional_templates))
        self.id2template = None
        self.template2id = None
        self._compute_template()
        self.env = FrotzEnv(self.rom_path, self.seed)

        self.env_num = env_num
        self.ps = None
        self.envs = None
        self.remotes = None
        self.work_remotes = None
        self.parallel = False
        self._init_parallel_workers()

        self.steps = 0
        self.max_word_len = self.bindings['max_word_length']

        self.word2id = None
        self.id2word = None
        self.noun_words = None
        self._compute_vocab_act()
        self.state2valid_acts = {}

    def _init_parallel_workers(self):
        """Spawn ``env_num`` daemon worker processes, each with a FrotzEnv.

        Does nothing when ``env_num`` is not positive; in that case
        ``parallel`` stays False and ``envs``/``remotes`` stay None.
        """
        if self.env_num <= 0:
            return
        self.parallel = True
        self.envs = [
            FrotzEnv(self.rom_path, self.seed) for _ in range(self.env_num)
        ]
        self.remotes, self.work_remotes = zip(
            *[Pipe() for _ in range(self.env_num)])
        self.ps = [
            Process(target=worker, args=(work_remote, remote, env))
            for work_remote, remote, env in zip(self.work_remotes,
                                                self.remotes, self.envs)
        ]
        for p in self.ps:
            p.daemon = True
            p.start()
        # The parent talks only through ``remotes``; close its copies of the
        # worker ends.
        for remote in self.work_remotes:
            remote.close()

    def _compute_vocab_act(self):
        """Build word<->id maps and the noun set from the game dictionary.

        Id 0 is ' ' and id 1 is '<s>'; dictionary words start at id 2.
        """
        env_dict = self.env.get_dictionary()
        vocab = {i + 2: str(v) for i, v in enumerate(env_dict)}
        vocab[0] = ' '
        vocab[1] = '<s>'
        self.word2id = {v: idx for idx, v in vocab.items()}
        self.id2word = vocab
        self.noun_words = {w.word for w in env_dict if w.is_noun}

    def _compute_template(self):
        """Index ``act_gen.templates`` into id<->template dictionaries."""
        self.id2template = dict(enumerate(self.act_gen.templates))
        self.template2id = {t: i for i, t in self.id2template.items()}

    def get_max_word_len(self):
        return self.max_word_len

    def get_dictionary(self):
        return self.env.get_dictionary()

    def get_bindings(self):
        return self.bindings

    def get_id2act_word(self):
        return self.id2word

    def get_id2template(self):
        return self.id2template

    def get_template2id(self):
        return self.template2id

    def get_act_word2id(self):
        return self.word2id

    def close(self):
        """Close the main emulator and any worker emulators/pipe ends.

        Works when no parallel workers were started (``envs`` and
        ``remotes`` are None); previously that case raised TypeError.
        """
        self.env.close()
        for env in self.envs or ():
            env.close()
        for remote in self.remotes or ():
            remote.close()

    def tmpl_to_str(self, template_idx, o1_id, o2_id):
        """Fill the 'OBJ' holes of template ``template_idx`` with words.

        Supports at most two holes: the first is filled with ``o1_id``'s
        word and the second with ``o2_id``'s word.
        """
        template_str = self.act_gen.templates[template_idx]
        holes = template_str.count('OBJ')
        assert holes <= 2
        if holes <= 0:
            return template_str
        if holes == 1:
            return template_str.replace('OBJ', self.id2word[o1_id])
        return (template_str.replace('OBJ', self.id2word[o1_id], 1)
                .replace('OBJ', self.id2word[o2_id], 1))

    def get_world_state_hash(self, ignore=False):
        """Return the emulator's own hash if ``ignore``, else the module hash."""
        if ignore:
            return self.env.get_world_state_hash()
        return _get_world_state_hash(self.env)

    def _identify_objects_on_current_state(self, ob):
        """Return parallel lists (objs, obj_ids) of distinct noun objects.

        Candidate words are sorted. A word is kept when its prefix,
        truncated to ``max_word_len``, is a dictionary noun whose id has
        not already been seen.
        """
        objs_raw = sorted(itertools.chain.from_iterable(
            self.env.identify_interactive_objects(ob)))
        objs = []
        obj_ids = []
        for obj in objs_raw:
            prefix = obj[:self.max_word_len]
            if prefix in self.noun_words:
                obj_id = self.word2id[prefix]
                if obj_id not in obj_ids:
                    objs.append(obj)
                    obj_ids.append(obj_id)
        return objs, obj_ids

    def _generate_all_template_actions(self, objs, obj_ids):
        return self.act_gen.generate_template_actions(objs, obj_ids)

    @staticmethod
    def _action_order(act):
        """Sort key choosing the canonical action within a diff group."""
        return act.template_id, tuple(act.obj_ids)

    def _check_valid_action_parallel(self, candidate_actions):
        """Validate ``candidate_actions`` across the worker processes.

        Actions are dealt round-robin to the workers, together with the
        current emulator state. Each worker replies with a mapping from
        world diff to the actions producing it. The mappings are merged,
        and for each diff, in sorted diff order, one action is kept: the
        one with the lowest (template_id, obj_ids).
        """
        chunks = [[
            act for act_id, act in enumerate(candidate_actions)
            if act_id % self.env_num == i
        ] for i in range(self.env_num)]
        state = self.env.get_state()

        for remote, chunk in zip(self.remotes, chunks):
            remote.send((state, chunk))

        results = [remote.recv() for remote in self.remotes]

        keys = sorted(set(itertools.chain.from_iterable(
            out.keys() for out in results)))
        # .get(): a worker that saw no action for a diff may lack that key,
        # which previously raised KeyError for plain-dict results.
        return [[
            min(itertools.chain.from_iterable(
                out.get(key, []) for out in results),
                key=self._action_order)
        ] for key in keys]

    # the serial version copied from jericho, as parallel target
    def _check_valid_action_serial(self, candidate_actions):
        """Serial counterpart of ``_check_valid_action_parallel``.

        Each action is tried from the saved state. An action counts only
        if it changes the score, ends the game, or changes the world, and
        its output does not contain '(Taken)'. Qualifying actions are
        grouped by world diff. The emulator state is restored at the end.
        """
        diff2acts = defaultdict(list)
        orig_score = self.env.get_score()
        state = self.env.get_state()
        for act in candidate_actions:
            self.env.set_state(state)
            if isinstance(act, defines.TemplateAction):
                obs, rew, done, info = self.env.step(act.action)
            else:
                obs, rew, done, info = self.env.step(act)

            if self.env.emulator_halted():
                self.env.reset()
                continue

            if info['score'] != orig_score or done or self.env.world_changed():
                # Heuristic to ignore actions with side-effect of taking items
                if '(Taken)' in obs:
                    continue
                diff = self.env._get_world_diff()
                diff2acts[diff].append(act)
        valid_acts = [[min(diff2acts[key], key=self._action_order)]
                      for key in sorted(diff2acts)]
        self.env.set_state(state)
        return valid_acts

    def _get_action_on_current_state(self, state_hash, ob, parallel,
                                     compute_actions):
        """Return (valid_acts, acts, objs, invalid_act) for ``state_hash``.

        Uses the cached entry if present. Otherwise, when
        ``compute_actions`` is true, computes the entry, caches it and
        returns it. Returns four empty lists when there is no entry and
        nothing is computed.
        """
        valid_ao = self.state2valid_acts.get(state_hash)
        if valid_ao is None and compute_actions:
            # Objects in the current location and inventory that are likely
            # interactive. Each raw entry may be a list: a noun followed by
            # adjectives.
            objs, obj_ids = self._identify_objects_on_current_state(ob)
            acts = self._generate_all_template_actions(objs, obj_ids)

            if parallel:
                valid_acts = self._check_valid_action_parallel(acts)
            else:
                valid_acts = self._check_valid_action_serial(acts)
            # Invalid-action tracking is disabled; always empty.
            invalid_act = []
            valid_ao = (valid_acts, acts, objs, invalid_act)
            self.state2valid_acts[state_hash] = valid_ao

        if valid_ao is None:
            return [], [], [], []
        return valid_ao

    @classmethod
    def _clear_act_info(cls, info):
        """Set every action-related ``info`` entry to a fresh empty list."""
        for key in cls._ACT_INFO_KEYS:
            info[key] = []

    def step(self, action, confidence=0, parallel=True, compute_actions=True):
        """Take ``action`` and return (combined_ob, reward, done, info).

        The returned observation is 'look|inventory|ob|action', each part
        passed through ``clean_obs``. ``info`` gains 'world_changed' and
        the action-info keys. A RuntimeError ends the episode.
        """
        ob, reward, done, info = self.env.step(action)
        world_changed = self.env.world_changed()
        info['world_changed'] = world_changed
        # Initialize with default values
        look = 'unknown'
        inv = 'unknown'
        if done:
            self._clear_act_info(info)
        else:
            try:
                # The action's effect may be stochastic, so "world changed"
                # is itself random. Re-try up to ``confidence`` times from
                # a saved state to reduce that noise (as in some Atari
                # setups), then restore the state.
                if not world_changed and confidence > 0:
                    save = self.env.get_state()
                    for _ in range(confidence):
                        self.env.step(action)
                        world_changed = self.env.world_changed()
                        if world_changed:
                            break
                    self.env.set_state(save)

                info['world_changed'] = world_changed
                self.steps += 1

                state_hash = _get_world_state_hash(self.env)
                save = self.env.get_state()
                look, _, _, _ = self.env.step('look')
                self.env.set_state(save)
                inv, _, _, _ = self.env.step('inventory')
                self.env.set_state(save)
                # Find Valid actions
                act_info = self._get_action_on_current_state(
                    state_hash, ob, parallel, compute_actions)
                info.update(zip(self._ACT_INFO_KEYS, act_info))

            except RuntimeError:
                print('RuntimeError: {}, Done: {}, Info: {}'.format(
                    clean_obs(ob), done, info))
                self._clear_act_info(info)
                done = True

        if self.step_limit and self.steps >= self.step_limit:
            done = True

        ob = (clean_obs(look) + '|' + clean_obs(inv) + '|' + clean_obs(ob) +
              '|' + clean_obs(action))
        return ob, reward, done, info

    def reset(self, parallel=True, compute_actions=True):
        """Reset the game and return (combined_initial_ob, info).

        The observation format matches ``step``, with 'look' as the action.
        """
        initial_ob, info = self.env.reset()
        save = self.env.get_state()
        self.steps = 0
        look, inv = '', ''
        try:
            state_hash = _get_world_state_hash(self.env)
            look, _, _, _ = self.env.step('look')
            self.env.set_state(save)
            inv, _, _, _ = self.env.step('inventory')
            self.env.set_state(save)

            # compute valid state for initial obs
            act_info = self._get_action_on_current_state(
                state_hash, initial_ob, parallel, compute_actions)
            info.update(zip(self._ACT_INFO_KEYS, act_info))
            info['world_changed'] = True

        except RuntimeError:
            print('RuntimeError: {}, Info: {}'.format(initial_ob, info))
            self._clear_act_info(info)
            info['world_changed'] = True
            # Saturate the step counter so the next step() reports done.
            # Without a limit there is nothing to saturate, and assigning
            # None would make ``self.steps += 1`` in step() raise TypeError.
            if self.step_limit:
                self.steps = self.step_limit

        initial_ob = clean_obs(look) + '|' + clean_obs(inv) + '|' + clean_obs(
            initial_ob) + '|' + clean_obs('look')
        return initial_ob, info

    def align_action_on_current_state(self, target_action, action_groups):
        """Return the index of the group that matches ``target_action``.

        Runs ``target_action`` and records its world diff. Then runs the
        first (archetype) action of each group and returns the index of
        the first group producing the same diff, or -1 if none does. The
        emulator state is restored after every trial.
        """
        state = self.env.get_state()
        obs, rew, done, info = self.env.step(target_action)
        target_diff = self.env._get_world_diff()
        # NOTE(review): this is the score *after* target_action, not before
        # it; confirm that this is the intended comparison baseline.
        orig_score = info['score']
        self.env.set_state(state)
        for idx, group in enumerate(action_groups):
            # The first action is the group's archetype; element 1 is its
            # template id and element 2 its object ids.
            act = group[0]
            template_id = act[1]
            obj_ids = act[2]
            obj1_id = obj_ids[0] if len(obj_ids) > 0 else None
            obj2_id = obj_ids[1] if len(obj_ids) > 1 else None

            act_str = self.tmpl_to_str(template_id, obj1_id, obj2_id)
            obs, rew, done, info = self.env.step(act_str)
            diff = None
            if info['score'] != orig_score or done or self.env.world_changed():
                diff = self.env._get_world_diff()

            self.env.set_state(state)
            if diff == target_diff:
                return idx

        return -1
Ejemplo n.º 14
0
class JerichoEnv:
    """Wrapper around a jericho ``FrotzEnv`` that augments observations.

    Observations returned by :meth:`step` and :meth:`reset` are the
    ``'|'``-joined string ``look|inventory|observation|action``. Valid actions
    for a world state are cached in a Redis database keyed by the emulator's
    world-state hash.
    """

    def __init__(self, rom_path, seed, vocab_rev, step_limit=None):
        """Store configuration. The emulator and Redis connection are created by ``create()``.

        Args:
            rom_path: path to the game ROM.
            seed: seed passed to ``FrotzEnv``.
            vocab_rev: mapping from object words (truncated to the bindings'
                ``max_word_length``) to vocabulary ids.
            step_limit: if truthy, ``step`` reports ``done`` once this many
                valid steps have been taken since the last reset.
        """
        self.rom_path = rom_path
        self.bindings = load_bindings(rom_path)
        self.act_gen = TemplateActionGenerator(self.bindings)
        self.seed = seed
        self.steps = 0
        self.step_limit = step_limit
        self.vocab_rev = vocab_rev
        # Set by create().
        self.env = None
        self.conn = None

    def create(self):
        """Create the emulator, call ``start_redis()``, connect to Redis and flush it.

        The connection targets localhost:6379, db 0. That database is flushed
        so that cached valid actions from earlier runs are discarded.
        """
        self.env = FrotzEnv(self.rom_path, self.seed)
        start_redis()
        self.conn = redis.Redis(host='localhost', port=6379, db=0)
        self.conn.flushdb()

    def step(self, action):
        """Execute ``action`` and return ``(ob, reward, done, info)``.

        An action that neither ends the game nor changes the world is treated
        as invalid. In that case ``(None, None, None, info)`` is returned with
        ``info['action_valid']`` false, and the step does not count toward
        ``step_limit``.

        For a valid action, ``ob`` is ``'look|inventory|observation|action'``.
        ``info['valid']`` holds the valid actions for the new state. It is
        ``[]`` when the game is over or probing raised ``RuntimeError``; in
        those cases look/inventory default to ``'unknown'``.
        """
        ob, reward, done, info = self.env.step(action)
        action_valid = done or self.env.world_changed()
        info['action_valid'] = action_valid
        if not action_valid:  # Exit early for invalid actions
            return None, None, None, info
        self.steps += 1

        # Defaults used when the game is over or probing the state fails.
        look = 'unknown'
        inv = 'unknown'
        info['valid'] = []
        if not done:
            try:
                # Probe 'look' and 'inventory' without advancing the game by
                # restoring the saved state after each probe.
                save = self.env.save_str()
                look, _, _, _ = self.env.step('look')
                self.env.load_str(save)
                inv, _, _, _ = self.env.step('inventory')
                self.env.load_str(save)
                info['valid'] = self._valid_actions(ob)
            except RuntimeError:
                print('RuntimeError: {}, Done: {}, Info: {}'.format(
                    clean(ob), done, info))
        if self.step_limit and self.steps >= self.step_limit:
            done = True
        ob = look + '|' + inv + '|' + ob + '|' + action
        return ob, reward, done, info

    def _valid_actions(self, ob):
        """Return the valid actions for the current world state, via the Redis cache.

        On a cache miss, actions are generated from the interactive objects
        found in ``ob``, filtered with ``find_valid_actions``, and stored as a
        ``'/'``-joined string of their ``str()`` forms. On a hit, that string
        is split back apart and each piece is ``eval``-ed. A cached empty list
        yields ``[]``.
        """
        world_state_hash = self.env.get_world_state_hash()
        cached = self.conn.get(world_state_hash)
        if cached is None:
            objs = [o[0] for o in self.env.identify_interactive_objects(ob)]
            max_len = self.bindings['max_word_length']
            obj_ids = [self.vocab_rev[o[:max_len]] for o in objs]
            acts = self.act_gen.generate_template_actions(objs, obj_ids)
            valid = self.env.find_valid_actions(acts)
            self.conn.set(world_state_hash, '/'.join(str(a) for a in valid))
            return valid
        cached = cached.decode('cp1252')
        if not cached:
            # An empty action list was cached as ''.
            return []
        # SECURITY NOTE(review): eval() executes whatever is stored in Redis.
        # This is only safe if nothing else can write to this database. It
        # also assumes str(action) round-trips through eval in this module's
        # namespace and that no action contains '/'. Consider JSON instead.
        return [eval(a) for a in cached.split('/')]

    def reset(self):
        """Reset the game and return ``('look|inventory|observation|look', info)``.

        If probing 'look'/'inventory' raises ``RuntimeError``, both parts fall
        back to empty strings. This also resets the step counter used for
        ``step_limit``.
        """
        initial_ob, info = self.env.reset()
        try:
            save = self.env.save_str()
            look, _, _, _ = self.env.step('look')
            self.env.load_str(save)
            inv, _, _, _ = self.env.step('inventory')
            self.env.load_str(save)
        except RuntimeError:
            print('RuntimeError: {}, Info: {}'.format(initial_ob, info))
            look, inv = '', ''
        self.steps = 0
        initial_ob = look + '|' + inv + '|' + initial_ob + '|' + 'look'
        return initial_ob, info

    def get_dictionary(self):
        """Return the emulator's dictionary, calling ``create()`` first if needed."""
        if not self.env:
            self.create()
        return self.env.get_dictionary()

    def get_action_set(self):
        """No fixed action set is provided by this environment; always ``None``."""
        return None

    def close(self):
        """Close the emulator and send Redis SHUTDOWN (with save), if they were created."""
        if self.env is not None:
            self.env.close()
        if self.conn is not None:
            self.conn.shutdown(save=True)