def cli():
    """Parse the RLlib command line and dispatch to the chosen subcommand.

    Three subcommands are registered, in this order: ``train``, ``evaluate``
    and ``rollout``. ``rollout`` is handled by the ``evaluate`` module after a
    non-fatal deprecation warning is emitted. If the parsed command matches
    none of them, the top-level help text is printed instead.
    """
    parser = argparse.ArgumentParser(
        description="Train or evaluate an RLlib Trainer.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=EXAMPLE_USAGE)
    subcommand_group = parser.add_subparsers(
        help="Commands to train or evaluate an RLlib agent.", dest="command")

    def _subparser_factory(name):
        # Each module's create_parser() is handed a callable that forwards
        # its keyword arguments to _SubParsersAction.add_parser; see
        # https://github.com/python/cpython/blob/master/Lib/argparse.py
        def _add(**kwargs):
            return subcommand_group.add_parser(name, **kwargs)

        return _add

    train_parser = train.create_parser(_subparser_factory("train"))
    evaluate_parser = evaluate.create_parser(_subparser_factory("evaluate"))
    rollout_parser = evaluate.create_parser(_subparser_factory("rollout"))

    options = parser.parse_args()
    command = options.command

    if command == "train":
        train.run(options, train_parser)
        return

    if command == "evaluate":
        evaluate.run(options, evaluate_parser)
        return

    if command == "rollout":
        # Deprecated alias: warn, then run through the evaluate module.
        deprecation_warning(
            old="rllib rollout", new="rllib evaluate", error=False)
        evaluate.run(options, rollout_parser)
        return

    parser.print_help()
def run(args, parser):
    """Ensure an ``on_train_result`` callback is configured, then train.

    ``args.config`` is modified in place: a ``'callbacks'`` mapping is
    created when absent, and the module's ``on_train_result`` is installed
    under the ``'on_train_result'`` key unless an entry already exists there
    (in which case the existing entry is kept and a notice is printed).
    Control is then handed to ``train.run`` with the same arguments.

    Args:
        args: Parsed command-line namespace exposing a ``config`` mapping.
        parser: The argument parser, passed through to ``train.run``.
    """
    callbacks = args.config.setdefault('callbacks', {})

    if 'on_train_result' in callbacks:
        print('on_train_result defined. Overriding default')
    else:
        callbacks['on_train_result'] = on_train_result

    train.run(args, parser)
def cli():
    """Parse the RLlib command line and dispatch to ``train`` or ``rollout``.

    Two subcommands are registered, in this order: ``train`` and ``rollout``.
    If the parsed command matches neither, the top-level help text is
    printed instead.
    """
    parser = argparse.ArgumentParser(
        description="Train or Run an RLlib Agent.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=EXAMPLE_USAGE)
    subcommand_group = parser.add_subparsers(
        help="Commands to train or run an RLlib agent.", dest="command")

    def _subparser_factory(name):
        # Each module's create_parser() is handed a callable that forwards
        # its keyword arguments to _SubParsersAction.add_parser; see
        # https://github.com/python/cpython/blob/master/Lib/argparse.py
        def _add(**kwargs):
            return subcommand_group.add_parser(name, **kwargs)

        return _add

    train_parser = train.create_parser(_subparser_factory("train"))
    rollout_parser = rollout.create_parser(_subparser_factory("rollout"))

    options = parser.parse_args()
    command = options.command

    if command == "train":
        train.run(options, train_parser)
        return

    if command == "rollout":
        rollout.run(options, rollout_parser)
        return

    parser.print_help()
activation_fn=None, weights_initializer=weights_initializer) return net_out, net def create_pacman_environment(layout_name='originalClassic', stick_actions=False): """ """ layout = pacman_env.layout.getLayout(layout_name) if layout is None: raise ValueError('No suck layout as %s' % layout_name) ghosts = [] for i in range(2): ghosts.append(pacman_env.ghostAgents.RandomGhost(i + 1)) #display = VizGraphics(includeInfoPane=False, zoom=0.4) display = pacman_env.matrixDisplay.PacmanGraphics(layout) #teacherAgents = [LeftTurnAgent(), GreedyAgent()] env = pacman_env.PacmanEnv(layout, ghosts, display) #, teacherAgents=teacherAgents) return env if __name__ == '__main__': ModelCatalog.register_custom_model("PacmanModel", PacmanModel) register_env("pacman", lambda _: create_pacman_environment()) parser = create_parser() args = parser.parse_args() run(args, parser)
def main():
    """Build the argument parser, parse the command line, and call ``run``."""
    parser = create_parser()
    # ``run`` receives both the parsed namespace and the parser itself.
    run(parser.parse_args(), parser)
def main():
    """Register raylab agents and environments, then parse args and run.

    Registration happens before the parser is created: agents first, then
    environments.
    """
    raylab.register_all_agents()
    raylab.register_all_environments()

    parser = create_parser()
    # ``run`` receives both the parsed namespace and the parser itself.
    run(parser.parse_args(), parser)