@classmethod
def create_env(cls, flags, featurizer=None):
    f = featurizer or X.Concat([X.Text(), X.ValidMoves()])
    print('loading env')
    start_time = time.time()
    env = gym.make(flags.env,
                   room_shape=(flags.height, flags.width),
                   partially_observable=flags.partial_observability,
                   max_placement=flags.max_placement,
                   featurizer=f,
                   shuffle_wiki=flags.shuffle_wiki,
                   time_penalty=flags.time_penalty)
    print('loaded env in {} seconds'.format(time.time() - start_time))
    return env
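# A minimal usage sketch (not from the source): drive create_env with an
# argparse-style Namespace carrying the attributes it reads. The class name
# Net is taken from the test() caller below; all flag values are illustrative.
from argparse import Namespace

flags = Namespace(
    env='rtfm:rock_paper_scissors-v0',  # any registered rtfm environment id
    height=12,
    width=22,
    partial_observability=False,
    max_placement=1,       # illustrative value
    shuffle_wiki=False,
    time_penalty=0.0,      # illustrative value
)
env = Net.create_env(flags)  # uses the default Text + ValidMoves featurizer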
def test(flags, num_eps: int = 1000):
    from rtfm import featurizer as X
    gym_env = Net.create_env(flags)
    if flags.mode == 'test_render':
        # also render the terminal view of the environment at each step
        gym_env.featurizer = X.Concat([gym_env.featurizer, X.Terminal()])
    env = environment.Environment(gym_env)

    if not flags.random_agent:
        model = Net.make(flags, gym_env)
        model.eval()
        if flags.xpid is None:
            checkpointpath = './results_latest/model.tar'
        else:
            checkpointpath = os.path.expandvars(os.path.expanduser(
                '%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))
        checkpoint = torch.load(checkpointpath, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    observation = env.initial()
    returns = []
    won = []
    entropy = []
    ep_len = []
    while len(won) < num_eps:
        done = False
        steps = 0
        while not done:
            if flags.random_agent:
                action = torch.zeros(1, 1, dtype=torch.int32)
                action[0][0] = random.randint(0, gym_env.action_space.n - 1)
                observation = env.step(action)
            else:
                agent_outputs = model(observation)
                observation = env.step(agent_outputs['action'])
                # track the policy entropy, averaged over the batch
                policy = F.softmax(agent_outputs['policy_logits'], dim=-1)
                log_policy = F.log_softmax(agent_outputs['policy_logits'], dim=-1)
                e = -torch.sum(policy * log_policy, dim=-1)
                entropy.append(e.mean(0).item())
            steps += 1
            done = observation['done'].item()
            if observation['done'].item():
                returns.append(observation['episode_return'].item())
                # a positive final reward counts as a win
                won.append(observation['reward'][0][0].item() > 0.5)
                ep_len.append(steps)
                # logging.info('Episode ended after %d steps. Return: %.1f',
                #              observation['episode_step'].item(),
                #              observation['episode_return'].item())
            if flags.mode == 'test_render':
                sleep_seconds = os.environ.get('DELAY', '0.3')
                time.sleep(float(sleep_seconds))
                if observation['done'].item():
                    print('Done: {}'.format('You won!!' if won[-1] else 'You lost!!'))
                    print('Episode steps: {}'.format(observation['episode_step']))
                    print('Episode return: {}'.format(observation['episode_return']))
                    done_seconds = os.environ.get('DONE', None)
                    if done_seconds is None:
                        print('Press Enter to continue')
                        input()
                    else:
                        time.sleep(float(done_seconds))
    env.close()
    logging.info(
        'Average returns over %i episodes: %.2f. Win rate: %.2f. Entropy: %.2f. Len: %.2f',
        num_eps,
        sum(returns) / len(returns),
        sum(won) / len(won),
        sum(entropy) / max(1, len(entropy)),
        sum(ep_len) / len(ep_len))
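# Sketch of calling test() for a random-agent baseline, reusing the flags
# object from above (values illustrative). With random_agent set, no
# checkpoint is loaded and actions are sampled uniformly from the action space.
flags.mode = 'test'          # 'test_render' additionally prints the terminal view
flags.random_agent = True
test(flags, num_eps=100)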
@classmethod
def create_env(cls, flags, featurizer=None):
    # widen the base-class featurizer set with relative-position features
    return super().create_env(
        flags,
        featurizer=featurizer or X.Concat([X.Text(), X.ValidMoves(), X.RelativePosition()]))
import argparse
import random
import sys

import getch
import gym

from rtfm.dynamics.monster import Player
from rtfm import featurizer as F

if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment to use', default='rock_paper_scissors')
    parser.add_argument('--height', type=int, help='room height', default=12)
    parser.add_argument('--width', type=int, help='room width', default=22)
    parser.add_argument('--seed', type=int, help='random seed', default=42)
    parser.add_argument('-p', '--partially_observable', action='store_true',
                        help='make the environment partially observable')
    parser.add_argument('-c', '--control', action='store_true', help='assume direct control')
    parser.add_argument('-w', '--shuffle_wiki', action='store_true',
                        help='shuffle facts in the wiki')
    args = parser.parse_args()

    featurizer = F.Concat([F.Progress(), F.ValidMoves(), F.Terminal()])
    env = gym.make('rtfm:{}-v0'.format(args.env),
                   featurizer=featurizer,
                   partially_observable=args.partially_observable,
                   room_shape=(args.height, args.width),
                   shuffle_wiki=args.shuffle_wiki)
    random.seed(args.seed)

    score = total = 0
    while True:
        env.reset()
        feat = reward = finished = won = None
        while not finished:
            if args.control:
                ch = None
                while ch not in Player.keymap:
                    print('Current score {} out of {}'.format(score, total))
                    print('Enter your command. x to quit.')
                    ch = getch.getch()
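                    # The source is truncated above; the remainder is a minimal,
                    # assumption-laden completion: 'x' exits, any key in
                    # Player.keymap becomes an action, and env.step is assumed to
                    # return the (feat, reward, finished, won) tuple implied by
                    # the initialization at the top of the episode loop.
                    if ch == 'x':
                        sys.exit(0)
                action = Player.keymap[ch]
            else:
                action = env.action_space.sample()
            feat, reward, finished, won = env.step(action)
        total += 1
        if won:
            score += 1

# Example invocation (script name hypothetical):
#   python play_gym.py --env rock_paper_scissors -c -w
# With -c/--control, moves are read from the keyboard via Player.keymap;
# without it, the sketch above falls back to uniformly sampled actions.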