def test_flatten_and_unflatten(self):
    """Check that unflatten(flatten(x)) gives back x for a game structure."""
    embedding = embed.make_game_embedding()
    # The identity map over the embedding provides the structure under test.
    original = embedding.map(lambda leaf: leaf)
    round_tripped = embedding.unflatten(embedding.flatten(original))
    self.assertEqual(round_tripped, original)
import tree import utils def get_experiment_directory(): # create directory for tf checkpoints and other experiment artifacts today = datetime.date.today() expt_tag = f'{today.year}-{today.month}-{today.day}_{secrets.token_hex(8)}' expt_dir = f'experiments/{expt_tag}' os.makedirs(expt_dir, exist_ok=True) return expt_dir # necessary because our dataset has some mismatching types, which ultimately # come from libmelee occasionally giving differently-typed data # Won't be necessary if we re-generate the dataset. embed_game = embed.make_game_embedding() def sanitize_game(game): """Casts inputs to the right dtype and discard unused inputs.""" gamestates, counts, rewards = game gamestates = embed_game.map(lambda e, a: a.astype(e.dtype), gamestates) return gamestates, counts, rewards def sanitize_batch(batch): game, restarting = batch game = sanitize_game(game) return game, restarting class TrainManager: def __init__(self, learner, data_source, step_kwargs={}):
import sonnet as snt
import tensorflow as tf
import embed
import utils

# Player controllers are left out of the game embedding: the opponent's
# controller is not used, and our own is exposed separately in the input.
embed_game = embed.make_game_embedding(
    player_config={'with_controller': False})


def process_inputs(inputs):
    """Embed the gamestate and append the p1 controller embedding.

    `inputs` is a `(gamestate, p1_controller_embed)` pair. The gamestate is
    passed through `embed_game` and concatenated with the controller
    embedding along the last axis.
    """
    gamestate, controller_embedding = inputs
    pieces = [embed_game(gamestate), controller_embedding]
    return tf.concat(pieces, axis=-1)


class Network(snt.Module):
    """Base class for recurrent networks; subclasses supply the state and step."""

    def initial_state(self, batch_size):
        """Return the initial recurrent state; must be overridden."""
        raise NotImplementedError()

    def step(self, inputs, prev_state):
        """Returns outputs and next recurrent state.

        inputs: (batch_size, x_dim)
        """
        raise NotImplementedError()

    def unroll(self, inputs, initial_state):
        """Apply `step` across `inputs` with `utils.dynamic_rnn`."""
        step_fn = self.step
        return utils.dynamic_rnn(step_fn, inputs, initial_state)
def main(saved_model_path, dolphin_path, iso_path, _log):
    """Play a saved policy on port 1 against a level-9 CPU on port 2.

    Loads the policy, starts the console, connects both controllers and then
    loops forever. In game (or sudden death) the policy is sampled and its
    output is applied to the port-1 controller; in any other menu state both
    ports are steered to Fox on Yoshi's Story.

    Args:
        saved_model_path: Path given to `tf.saved_model.load`.
        dolphin_path: Passed as `path` to `melee.Console`.
        iso_path: Passed as `iso_path` to `console.run`.
        _log: Logger exposing `info` and `error`.

    Returns:
        None if the console connection fails; otherwise the loop never
        returns. Exits the process with status -1 if a controller fails to
        connect.
    """
    embed_game = embed.make_game_embedding()
    policy = tf.saved_model.load(saved_model_path)

    def sample(*structs):
        # Nested inputs are flattened into positional arguments.
        return policy.sample(*tf.nest.flatten(structs))

    # NOTE(review): 1 presumably is the batch size, matching the leading
    # axis of size 1 added to the inputs below -- confirm.
    hidden_state = policy.initial_state(1)

    console = melee.Console(path=dolphin_path)

    # This isn't necessary, but makes it so that Dolphin will get killed
    # when you ^C.
    def signal_handler(sig, frame):
        console.stop()
        print("Shutting down cleanly...")
        # NOTE: the process is not exited here (sys.exit(0) is disabled),
        # so the main loop keeps running after the console is stopped.

    signal.signal(signal.SIGINT, signal_handler)

    controller = melee.Controller(
        console=console, port=1, type=melee.ControllerType.STANDARD)
    cpu_controller = melee.Controller(
        console=console, port=2, type=melee.ControllerType.STANDARD)

    # Run the console
    console.run(iso_path=iso_path)

    # Connect to the console
    _log.info("Connecting to console...")
    if not console.connect():
        _log.error("Failed to connect to the console.")
        return
    _log.info("Console connected")

    for c in [controller, cpu_controller]:
        print("Connecting controller to console...")
        if not c.connect():
            print("ERROR: Failed to connect the controller.")
            sys.exit(-1)
        print("Controller connected")

    # Repeat count last sampled by the policy; fed back as part of its input.
    action_repeat = 0
    # Number of upcoming in-game frames to skip before sampling again.
    repeats_left = 0

    # Main loop
    while True:
        # "step" to the next frame
        gamestate = console.step()
        if gamestate is None:
            continue

        if gamestate.frame == -123:  # initial frame
            controller.release_all()

        # The console object keeps track of how long your bot is taking to
        # process frames, and can warn you if it's taking too long.
        frame_ms = console.processingtime * 1000
        if frame_ms > 12:
            print("WARNING: Last frame took " + str(frame_ms) +
                  "ms to process.")

        # What menu are we in?
        if gamestate.menu_state in [
            melee.Menu.IN_GAME, melee.Menu.SUDDEN_DEATH
        ]:
            if repeats_left > 0:
                repeats_left -= 1
                continue

            embedded_game = embed_game.from_state(gamestate), action_repeat
            # Add a leading axis of size 1 to every input leaf.
            batched_game = tf.nest.map_structure(
                lambda a: np.expand_dims(a, 0), embedded_game)
            sampled_controller_with_repeat, hidden_state = sample(
                batched_game, hidden_state)
            # Drop that leading axis again and convert tensors to numpy.
            sampled_controller_with_repeat = tf.nest.map_structure(
                lambda t: np.squeeze(t.numpy(), 0),
                sampled_controller_with_repeat)
            sampled_controller = sampled_controller_with_repeat['controller']
            action_repeat = sampled_controller_with_repeat['action_repeat']
            # np.squeeze returns an ndarray, so a plain assignment would
            # alias it and `repeats_left -= 1` would decrement action_repeat
            # in place, corrupting the value fed back to the policy. Count
            # down on an independent copy instead.
            repeats_left = np.copy(action_repeat)

            for b in embed.LEGAL_BUTTONS:
                if sampled_controller['button'][b.value]:
                    controller.press_button(b)
                else:
                    controller.release_button(b)
            main_stick = sampled_controller["main_stick"]
            controller.tilt_analog(melee.Button.BUTTON_MAIN, *main_stick)
            c_stick = sampled_controller["c_stick"]
            controller.tilt_analog(melee.Button.BUTTON_C, *c_stick)
            controller.press_shoulder(
                melee.Button.BUTTON_L, sampled_controller["l_shoulder"])
            controller.press_shoulder(
                melee.Button.BUTTON_R, sampled_controller["r_shoulder"])
        else:
            # Port 1 only makes its selections; the CPU port is the one
            # called with autostart=True.
            melee.MenuHelper.menu_helper_simple(gamestate,
                                                controller,
                                                melee.Character.FOX,
                                                melee.Stage.YOSHIS_STORY,
                                                connect_code=None,
                                                autostart=False,
                                                swag=False)
            melee.MenuHelper.menu_helper_simple(gamestate,
                                                cpu_controller,
                                                melee.Character.FOX,
                                                melee.Stage.YOSHIS_STORY,
                                                connect_code=None,
                                                cpu_level=9,
                                                autostart=True,
                                                swag=False)