def __init__(
    self,
    featurizer=F.Progress(),
    partially_observable=False,
    max_iter=1000,
    max_placement=2,
    max_name=2,
    max_inv=10,
    max_wiki=80,
    max_task=40,
    time_penalty=-.02,
    shuffle_wiki=False,
):
    """Create the simulation objects, record the configuration, and reset.

    Every keyword argument is stored unchanged on the instance under the
    same name. The constructor finishes by calling ``self.build_vocab()``
    and then ``self.reset()``.

    Args:
        featurizer: object exposing ``get_observation_space(task)``; it is
            queried once here to fill ``self.observation_space``.
            NOTE(review): the default ``F.Progress()`` is evaluated once,
            when the function is defined, so every instance that relies on
            the default shares that single object.
        partially_observable: stored as ``self.partially_observable``.
        max_iter: stored as ``self.max_iter``.
        max_placement: stored as ``self.max_placement``.
        max_name: stored as ``self.max_name``.
        max_inv: stored as ``self.max_inv``.
        max_wiki: stored as ``self.max_wiki``.
        max_task: stored as ``self.max_task``.
        time_penalty: stored as ``self.time_penalty``.
        shuffle_wiki: stored as ``self.shuffle_wiki``.
    """
    # Simulation backbone.
    self.world = W.World()
    self.engine = E.Engine()

    # Episode state starts empty: no recorded history, iteration zero.
    self.partially_observable = partially_observable
    self.history, self.iter = [], 0

    # Configuration limits and reward shaping, kept verbatim.
    self.max_iter = max_iter
    self.max_placement = max_placement
    self.max_name = max_name
    self.max_inv = max_inv
    self.max_wiki = max_wiki
    self.max_task = max_task
    self.time_penalty = time_penalty
    self.shuffle_wiki = shuffle_wiki

    # Rendering, featurization and the acting agent.
    self.renderer = F.Terminal()
    self.featurizer = featurizer
    self.agent = M.QueuedAgent()

    # Action space is taken from the agent class itself.
    self.action_space = M.QueuedAgent.valid_moves

    # Observation shapes are dictated by the featurizer.
    self.observation_space = self.featurizer.get_observation_space(self)

    # The vocabulary is seeded with these tokens before build_vocab() runs.
    seed_tokens = ['pad', 'eos', '']
    self.vocab = Vocab(seed_tokens)
    self.build_vocab()

    self.reset()
def test(flags, num_eps: int = 1000):
    """Roll out evaluation episodes and log aggregate statistics.

    An environment is built with ``Net.create_env(flags)``. Unless
    ``flags.random_agent`` is set, a model is created with ``Net.make`` and
    its weights are loaded from ``./results_latest/model.tar`` (when
    ``flags.xpid`` is None) or from ``<flags.savedir>/<flags.xpid>/model.tar``.
    Episodes are played until ``num_eps`` of them have finished; an episode
    counts as won when the reward on its final step exceeds 0.5.

    When ``flags.mode == 'test_render'`` a terminal featurizer is appended to
    the environment's featurizer, each step is followed by a pause of
    ``$DELAY`` seconds (default 0.3), and at the end of every episode the
    outcome is printed and the loop either waits for Enter or sleeps for
    ``$DONE`` seconds when that variable is set.

    Args:
        flags: run configuration; the attributes read here are ``mode``,
            ``random_agent``, ``xpid`` and ``savedir`` (plus whatever
            ``Net.create_env`` / ``Net.make`` consume).
        num_eps: number of finished episodes to collect. With a value <= 0
            no episode is played and every logged average is 0.

    Returns:
        None. The summary is emitted through ``logging.info``.
    """
    from rtfm import featurizer as X

    def _mean(values):
        # Empty lists happen for entropy (random agent) and for everything
        # when num_eps <= 0; report 0 instead of raising ZeroDivisionError.
        return sum(values) / max(1, len(values))

    render = flags.mode == 'test_render'

    gym_env = Net.create_env(flags)
    if render:
        gym_env.featurizer = X.Concat([gym_env.featurizer, X.Terminal()])
    env = environment.Environment(gym_env)

    returns = []
    won = []
    entropy = []
    ep_len = []

    # Everything after the environment exists runs under try/finally so the
    # environment is closed even if checkpoint loading or a step fails.
    try:
        if not flags.random_agent:
            model = Net.make(flags, gym_env)
            model.eval()
            if flags.xpid is None:
                checkpointpath = './results_latest/model.tar'
            else:
                checkpointpath = os.path.expandvars(
                    os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))
            checkpoint = torch.load(checkpointpath, map_location='cpu')
            model.load_state_dict(checkpoint['model_state_dict'])

        observation = env.initial()

        while len(won) < num_eps:
            done = False
            steps = 0
            while not done:
                if flags.random_agent:
                    action = torch.zeros(1, 1, dtype=torch.int32)
                    action[0][0] = random.randint(0, gym_env.action_space.n - 1)
                    observation = env.step(action)
                else:
                    # Pure inference: no autograd graph is needed.
                    with torch.no_grad():
                        agent_outputs = model(observation)
                    observation = env.step(agent_outputs['action'])
                    logits = agent_outputs['policy_logits']
                    policy = F.softmax(logits, dim=-1)
                    log_policy = F.log_softmax(logits, dim=-1)
                    step_entropy = -torch.sum(policy * log_policy, dim=-1)
                    entropy.append(step_entropy.mean(0).item())

                steps += 1
                done = observation['done'].item()
                if done:
                    returns.append(observation['episode_return'].item())
                    won.append(observation['reward'][0][0].item() > 0.5)
                    ep_len.append(steps)

                if render:
                    sleep_seconds = os.environ.get('DELAY', '0.3')
                    time.sleep(float(sleep_seconds))

                    if done:
                        print('Done: {}'.format('You won!!' if won[-1] else 'You lost!!'))
                        print('Episode steps: {}'.format(observation['episode_step']))
                        print('Episode return: {}'.format(observation['episode_return']))
                        done_seconds = os.environ.get('DONE', None)
                        if done_seconds is None:
                            print('Press Enter to continue')
                            input()
                        else:
                            time.sleep(float(done_seconds))
    finally:
        env.close()

    logging.info(
        'Average returns over %i episodes: %.2f. Win rate: %.2f. Entropy: %.2f. Len: %.2f',
        num_eps, _mean(returns), _mean(won), _mean(entropy), _mean(ep_len))
from rtfm.dynamics.monster import Player from rtfm import featurizer as F if __name__ == '__main__': parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--env', help='environment to use', default='rock_paper_scissors') parser.add_argument('--height', type=int, help='room size', default=12) parser.add_argument('--width', type=int, help='room size', default=22) parser.add_argument('--seed', type=int, help='seed', default=42) parser.add_argument('-p', '--partially_observable', action='store_true', help='only show partial observability') parser.add_argument('-c', '--control', action='store_true', help='assume direct control') parser.add_argument('-w', '--shuffle_wiki', action='store_true', help='shuffle facts in the wiki') args = parser.parse_args() featurizer = F.Concat([F.Progress(), F.ValidMoves(), F.Terminal()]) env = gym.make('rtfm:{}-v0'.format(args.env), featurizer=featurizer, partially_observable=args.partially_observable, room_shape=(args.height, args.width), shuffle_wiki=args.shuffle_wiki) random.seed(args.seed) score = total = 0 while True: env.reset() feat = reward = finished = won = None while not finished: if args.control: ch = None while ch not in Player.keymap: print('Current score {} out of {}'.format(score, total)) print('Enter your command. x to quit.') ch = getch.getch()