Exemplo n.º 1
0
 def create(self):
     """Build the Jericho environment, its helpers and two Redis client handles."""
     rom = self.rom_path
     self.env = jericho.FrotzEnv(rom, self.seed)

     # Game bindings feed both the template generator and the word-length limit.
     game_bindings = jericho.load_bindings(rom)
     self.bindings = game_bindings
     self.act_gen = TemplateActionGenerator(game_bindings)
     self.max_word_len = game_bindings['max_word_length']

     vocab_pair = load_vocab(self.env)
     self.vocab, self.vocab_rev = vocab_pair

     # Same local Redis server for both clients; only the database index differs.
     redis_target = {'host': 'localhost', 'port': 6379}
     self.conn_valid = redis.Redis(db=0, **redis_target)
     self.conn_openie = redis.Redis(db=1, **redis_target)
Exemplo n.º 2
0
def test_valid_action_identification():
    """Check that a set of known actions is reported as valid for 905.z5."""
    game_file = pjoin(DATA_PATH, '905.z5')
    env = jericho.FrotzEnv(game_file)
    generator = TemplateActionGenerator(jericho.load_bindings(game_file))
    _obs, _info = env.reset()

    # The objects are listed by hand instead of being queried through
    # env.identify_interactive_objects(use_object_tree=True).
    objects = ['phone', 'keys', 'wallet']
    valid = env.find_valid_actions(generator.generate_actions(objects))

    expected_actions = (
        'take wallet',
        'open wallet',
        'take keys',
        'get up',
        'take phone',
    )
    for action in expected_actions:
        assert action in valid
Exemplo n.º 3
0
    def __init__(self, args):
        """Set up logging, vocabularies, replay buffer, networks and optimizer.

        Args:
            args: Run configuration object. The attributes read here are
                output_dir, log_freq, update_freq_td, update_freq_tar,
                rom_path, run_number, spm_path, replay_buffer_type,
                replay_buffer_size, lr, steps, batch_size, gamma and rho.
                The object itself is also forwarded to ``TDQN``.

        Raises:
            ValueError: If ``args.replay_buffer_type`` is neither
                ``'priority'`` nor ``'standard'``. Previously this case left
                ``self.replay_buffer`` undefined and only failed later with
                an AttributeError on first use.
        """
        configure_logger(args.output_dir)
        log(args)
        self.args = args

        self.log_freq = args.log_freq
        self.update_freq = args.update_freq_td
        self.update_freq_tar = args.update_freq_tar

        # Run name shown in wandb: fixed prefix + ROM path + run number.
        self.filename = 'ptdqn' + args.rom_path + str(args.run_number)
        wandb.init(project="my-project", name=self.filename)

        # SentencePiece model and the game-specific action vocabulary.
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(args.spm_path)
        self.binding = jericho.load_bindings(args.rom_path)
        self.vocab_act, self.vocab_act_rev = self.load_vocab_act(args.rom_path)
        vocab_size = len(self.sp)
        vocab_size_act = len(self.vocab_act.keys())

        self.template_generator = TemplateActionGenerator(self.binding)
        self.template_size = len(self.template_generator.templates)

        if args.replay_buffer_type == 'priority':
            self.replay_buffer = PriorityReplayBuffer(
                int(args.replay_buffer_size))
        elif args.replay_buffer_type == 'standard':
            self.replay_buffer = ReplayBuffer(int(args.replay_buffer_size))
        else:
            # Fail fast instead of leaving self.replay_buffer unset.
            raise ValueError(
                "Unknown replay_buffer_type {!r}; expected 'priority' or "
                "'standard'.".format(args.replay_buffer_type))

        self.action_size = self.template_size
        self.action_parameter_size = vocab_size_act

        # Both networks are built from the same argument list so they cannot
        # drift apart.
        network_dims = (self.action_size, self.action_parameter_size,
                        self.template_size, vocab_size, vocab_size_act)
        self.model = TDQN(args, *network_dims).cuda()
        self.target_model = TDQN(args, *network_dims).cuda()

        # Only self.model's parameters are handed to the optimizer.
        self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)

        self.num_steps = args.steps
        self.batch_size = args.batch_size
        self.gamma = args.gamma

        self.rho = args.rho
        self.vocab_size_act = vocab_size_act

        self.bce_loss = nn.BCELoss()
Exemplo n.º 4
0
def test_copy():
    """Check that a copy of the env replays the remaining walkthrough identically.

    After every walkthrough step the environment is copied, the rest of the
    walkthrough is played on the copy, and each result is compared with the
    result recorded during an uninterrupted reference run.
    """
    rom = pjoin(DATA_PATH, "905.z5")
    bindings = jericho.load_bindings(rom)
    env = jericho.FrotzEnv(rom, seed=bindings['seed'])
    env.reset()

    commands = bindings['walkthrough'].split('/')
    n_commands = len(commands)

    # Reference run: record what every command returns.
    reference = []
    for command in commands:
        reference.append(env.step(command))

    env.reset()
    for idx, command in enumerate(commands):
        obs, rew, done, info = env.step(command)

        next_idx = idx + 1
        if next_idx >= n_commands:
            # Last command: nothing is left to replay on a copy.
            continue

        clone = env.copy()
        for pos in range(next_idx, n_commands):
            obs, rew, done, info = clone.step(commands[pos])
            assert (obs, rew, done, info) == reference[pos]
Exemplo n.º 5
0
    def __init__(self, args):
        """Store run settings and load the tokenizer, action vocabulary and templates.

        Args:
            args: Run configuration object; log_freq, update_freq_td,
                update_freq_tar, rom_path, run_number, spm_path and steps
                are read from it.
        """
        self.args = args

        self.log_freq = args.log_freq
        self.update_freq = args.update_freq_td
        self.update_freq_tar = args.update_freq_tar

        # Run name shown in wandb: fixed prefix + ROM path + run number.
        rom_path = args.rom_path
        self.filename = ''.join(('random', rom_path, str(args.run_number)))
        wandb.init(project="my-project", name=self.filename)

        tokenizer = spm.SentencePieceProcessor()
        self.sp = tokenizer
        tokenizer.Load(args.spm_path)

        self.binding = jericho.load_bindings(rom_path)
        act_vocab, act_vocab_rev = self.load_vocab_act(rom_path)
        self.vocab_act = act_vocab
        self.vocab_act_rev = act_vocab_rev
        vocab_size = len(self.sp)  # computed but not used further in this method
        self.vocab_size_act = len(self.vocab_act.keys())

        generator = TemplateActionGenerator(self.binding)
        self.template_generator = generator
        self.template_size = len(generator.templates)

        self.num_steps = args.steps
Exemplo n.º 6
0
 def test_load_bindings(self):
     """An empty name raises ValueError; '905' and '905.z5' give equal bindings."""
     with self.assertRaises(ValueError):
         jericho.load_bindings("")

     without_ext = jericho.load_bindings("905")
     with_ext = jericho.load_bindings("905.z5")
     assert without_ext == with_ext
Exemplo n.º 7
0
        for obj in env.identify_interactive_objects(use_object_tree=True)
    ]
    #ex. ['mailbox', 'boarded', 'white']
    candidate_actions = act_gen.generate_actions(interactive_objs)
    #ex. ['drive boarded', 'swim in mailbox', 'jump white', 'kick boarded','pour white in boarded', ... ]
    valid_actions = env.find_valid_actions(candidate_actions)
    chosen_action = random.choice(valid_actions)
    if args.verbose:
        print(colored(("valid actions:", valid_actions), 'green'))
        print(colored(("chosen action:", chosen_action), 'red'))
    return chosen_action


# Script entry: build the environment and the template-based action generator
# from the ROM given on the command line.
args = parse_args()
env = FrotzEnv(args.rom_path, seed=12)  # fixed seed
bindings = load_bindings(args.rom_path)
act_gen = TemplateActionGenerator(bindings)
obs, info = env.reset()
done = False

while not done:
    # With the debug flag set, drop into the ipdb debugger on every iteration.
    if args.debug:
        import ipdb
        ipdb.set_trace()
    # args.max_iterations is used as a countdown; the loop stops once it
    # drops below 1.
    args.max_iterations -= 1
    if (args.max_iterations < 1):
        print(colored('Max Iterations Exceeded -- BREAK', 'red'))
        break

    # Use TemplateActionGenerator to select an action
    chosen_action = get_random_action(env, act_gen, args)
Exemplo n.º 8
0
def parse_args(argv=None):
    """Parse this script's command-line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` calls behave as before.

    Returns:
        argparse.Namespace with ``filename`` (required positional),
        ``walkthrough`` (``None`` unless given) and ``skip_to`` (int,
        default 0).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("filename",
                        help="Path to a Z-Machine game.")
    parser.add_argument("--walkthrough",
                        help="External walkthrough (one command per line). Default: use Jericho's one, if it exists.")
    parser.add_argument("--skip-to", type=int, default=0,
                        help="Auto-play walkthrough until the nth command before dropping into interactive mode.")

    return parser.parse_args(argv)

args = parse_args()

# Load the game's bindings and create the environment with the seed they provide.
bindings = jericho.load_bindings(args.filename)
env = jericho.FrotzEnv(args.filename, seed=bindings['seed'])

# `history` collects environment states; the first entry is the state right
# after reset.
history = []
obs, info = env.reset()

history.append(env.get_state())

# NOTE(review): "WALKTRHOUGH" looks like a misspelling of "WALKTHROUGH"; the
# name is left unchanged because other code may reference it.
STEP_BY_STEP_WALKTRHOUGH = True

# Commands in the bundled walkthrough are separated by '/'. When the bindings
# have no 'walkthrough' key this yields [''] (one empty command), not [].
walkthrough = bindings.get('walkthrough', '').split('/')
if args.walkthrough:
    walkthrough = []
    for line in open(args.walkthrough):
        cmd = line.split("#")[0].strip()
        if cmd:
Exemplo n.º 9
0
    parser.add_argument("filenames", nargs="+",
                        help="Path to a Z-Machine game(s).")
    parser.add_argument("--debug", action="store_true",
                        help="Launch ipdb on FAIL.")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Print the last observation when not achieving max score.")
    return parser.parse_args()

args = parse_args()

# Pad every filename to the longest one so the status column lines up.
filename_max_length = max(map(len, args.filenames))
for filename in sorted(args.filenames):
    # Print the name without a newline; the status is appended on the same line.
    print(filename.ljust(filename_max_length), end=" ")
    try:
        bindings = jericho.load_bindings(filename)
    except ValueError:
        # load_bindings raised ValueError for this file: report it and move on.
        print(colored("SKIP\tUnsupported game", 'yellow'))
        continue

    if "walkthrough" not in bindings:
        print(colored("SKIP\tMissing walkthrough", 'yellow'))
        continue

    # Fresh environment per game, seeded with the seed from its bindings.
    env = jericho.FrotzEnv(filename, seed=bindings['seed'])
    env.reset()

    # Play the bundled walkthrough; its commands are separated by '/'.
    walkthrough = bindings['walkthrough'].split('/')
    for cmd in walkthrough:
        obs, rew, done, info = env.step(cmd)