Exemplo n.º 1
0
    def test_flatten_and_unflatten(self):
        """Flattening then unflattening a game-embedding structure is lossless."""
        game_embedding = embed.make_game_embedding()

        # Mapping the identity over the embedding gives a structure shaped
        # like the embedding itself, which we then round-trip.
        original = game_embedding.map(lambda e: e)
        round_tripped = game_embedding.unflatten(
            game_embedding.flatten(original))

        self.assertEqual(round_tripped, original)
Exemplo n.º 2
0
import tree

import utils

def get_experiment_directory(base_dir='experiments'):
  """Create and return a fresh directory for tf checkpoints and other artifacts.

  The directory is named `<YYYY-MM-DD>_<16 hex chars>` inside `base_dir`. The
  date is zero-padded (ISO 8601) so that directory names sort chronologically.
  The random suffix makes collisions between runs started on the same day
  unlikely.

  Args:
    base_dir: Parent directory for all experiments; created if missing.

  Returns:
    Path of the newly created experiment directory.
  """
  # isoformat() zero-pads month/day; the old f'{y}-{m}-{d}' form did not,
  # so e.g. '2021-10-1' sorted before '2021-9-30'.
  expt_tag = f'{datetime.date.today().isoformat()}_{secrets.token_hex(8)}'
  expt_dir = os.path.join(base_dir, expt_tag)
  os.makedirs(expt_dir, exist_ok=True)
  return expt_dir

# Module-level game embedding used by sanitize_game below to cast each
# gamestate field to the dtype this embedding expects. That casting is
# necessary because our dataset has some mismatching types, which ultimately
# come from libmelee occasionally giving differently-typed data.
# It won't be necessary if we re-generate the dataset.
embed_game = embed.make_game_embedding()

def sanitize_game(game):
  """Casts every gamestate field to the dtype expected by `embed_game`.

  Args:
    game: A (gamestates, counts, rewards) triple. `gamestates` must be a
      nested structure matching `embed_game`, with leaves that support
      `.astype` (presumably numpy arrays -- confirm against the data loader).

  Returns:
    A (gamestates, counts, rewards) triple in which only `gamestates` has been
    cast; `counts` and `rewards` are passed through unchanged. No fields are
    dropped.
  """
  gamestates, counts, rewards = game
  # embed_game.map pairs each sub-embedding `e` with the matching leaf `a`
  # of `gamestates`, so each leaf is cast to its own embedding's dtype.
  gamestates = embed_game.map(lambda e, a: a.astype(e.dtype), gamestates)
  return gamestates, counts, rewards

def sanitize_batch(batch):
  """Sanitizes the game half of a (game, restarting) batch pair.

  The `restarting` element is returned untouched next to the cast game.
  """
  game, restarting = batch
  return sanitize_game(game), restarting

class TrainManager:

  def __init__(self, learner, data_source, step_kwargs={}):
Exemplo n.º 3
0
import sonnet as snt
import tensorflow as tf

import embed
import utils

# Build the game embedding without controller inputs (with_controller=False in
# player_config; presumably applied to every player -- confirm in embed).
# We don't want the opponent's controller, and our own controller is supplied
# separately as p1_controller_embed in process_inputs.
embed_game = embed.make_game_embedding(
    player_config={'with_controller': False})


def process_inputs(inputs):
    """Embeds the gamestate and appends our own controller embedding.

    Args:
      inputs: A (gamestate, p1_controller_embed) pair.

    Returns:
      embed_game(gamestate) concatenated with p1_controller_embed along the
      last axis.
    """
    gamestate, controller_embedding = inputs
    pieces = [embed_game(gamestate), controller_embedding]
    return tf.concat(pieces, axis=-1)


class Network(snt.Module):
    """Recurrent network interface; subclasses implement initial_state/step."""

    def initial_state(self, batch_size):
        """Returns the recurrent state to begin with for `batch_size` rows."""
        raise NotImplementedError()

    def step(self, inputs, prev_state):
        """Returns outputs and next recurrent state.

        inputs: (batch_size, x_dim)
        """
        raise NotImplementedError()

    def unroll(self, inputs, initial_state):
        """Applies `step` repeatedly via utils.dynamic_rnn from `initial_state`."""
        step_fn = self.step
        return utils.dynamic_rnn(step_fn, inputs, initial_state)
Exemplo n.º 4
0
def main(saved_model_path, dolphin_path, iso_path, _log):
    """Plays a saved policy against a CPU opponent in Dolphin.

    Loads the policy, launches Dolphin with the given ISO and drives port 1
    with the policy against a level-9 CPU on port 2 (both Fox on Yoshi's
    Story). Runs until SIGINT (^C), which stops Dolphin and ends the loop.

    Args:
      saved_model_path: Path passed to `tf.saved_model.load`. The loaded
        object must provide `initial_state(batch_size)` and `sample(...)`;
        `sample` returns (outputs, next_hidden_state), where outputs has
        'controller' and 'action_repeat' entries.
      dolphin_path: Dolphin path, passed to `melee.Console(path=...)`.
      iso_path: Melee ISO path, passed to `console.run(iso_path=...)`.
      _log: Logger (presumably injected by the experiment framework).
    """
    embed_game = embed.make_game_embedding()
    policy = tf.saved_model.load(saved_model_path)

    def sample(*structs):
        # policy.sample is called with the flattened leaves of all inputs
        # (presumably how the model was exported -- confirm in the exporter).
        return policy.sample(*tf.nest.flatten(structs))

    hidden_state = policy.initial_state(1)  # batch of a single game

    console = melee.Console(path=dolphin_path)

    shutting_down = False

    # This isn't necessary, but makes it so that Dolphin will get killed when
    # you ^C. The flag also ends the main loop: previously nothing broke out of
    # `while True`, so it kept calling console.step() after Dolphin was stopped.
    def signal_handler(sig, frame):
        nonlocal shutting_down
        shutting_down = True
        console.stop()
        print("Shutting down cleanly...")

    signal.signal(signal.SIGINT, signal_handler)

    controller = melee.Controller(console=console,
                                  port=1,
                                  type=melee.ControllerType.STANDARD)
    cpu_controller = melee.Controller(console=console,
                                      port=2,
                                      type=melee.ControllerType.STANDARD)

    # Run the console
    console.run(iso_path=iso_path)

    # Connect to the console
    _log.info("Connecting to console...")
    if not console.connect():
        _log.error("Failed to connect to the console.")
        console.stop()  # don't leave the Dolphin launched above running
        return
    _log.info("Console connected")

    for c in [controller, cpu_controller]:
        _log.info("Connecting controller to console...")
        if not c.connect():
            _log.error("ERROR: Failed to connect the controller.")
            console.stop()  # don't leave the Dolphin launched above running
            sys.exit(-1)
        _log.info("Controller connected")

    # action_repeat: last sampled repeat count, fed back to the policy.
    # repeats_left: frames still to skip before sampling again.
    action_repeat = 0
    repeats_left = 0

    # Main loop
    while not shutting_down:
        # "step" to the next frame
        gamestate = console.step()
        if gamestate is None:
            continue

        if gamestate.frame == -123:  # initial frame
            controller.release_all()

        # The console object keeps track of how long your bot is taking to
        # process frames; warn when a frame took more than 12 ms.
        if console.processingtime * 1000 > 12:
            _log.warning("WARNING: Last frame took %sms to process.",
                         console.processingtime * 1000)

        if gamestate.menu_state not in (melee.Menu.IN_GAME,
                                        melee.Menu.SUDDEN_DEATH):
            _navigate_menus(gamestate, controller, cpu_controller)
            continue

        if repeats_left > 0:
            # Leave the controller untouched while the last action repeats.
            repeats_left -= 1
            continue

        embedded_game = embed_game.from_state(gamestate), action_repeat
        # Add a leading batch dimension of size 1 to every leaf.
        batched_game = tf.nest.map_structure(
            lambda a: np.expand_dims(a, 0), embedded_game)
        sampled, hidden_state = sample(batched_game, hidden_state)
        # Drop the batch dimension again and convert tensors to numpy.
        sampled = tf.nest.map_structure(
            lambda t: np.squeeze(t.numpy(), 0), sampled)
        action_repeat = sampled['action_repeat']
        repeats_left = action_repeat

        _send_controller_state(controller, sampled['controller'])


def _send_controller_state(controller, sampled_controller):
    """Applies a sampled controller (buttons, sticks, shoulders) to `controller`."""
    for b in embed.LEGAL_BUTTONS:
        if sampled_controller['button'][b.value]:
            controller.press_button(b)
        else:
            controller.release_button(b)
    controller.tilt_analog(melee.Button.BUTTON_MAIN,
                           *sampled_controller["main_stick"])
    controller.tilt_analog(melee.Button.BUTTON_C,
                           *sampled_controller["c_stick"])
    controller.press_shoulder(melee.Button.BUTTON_L,
                              sampled_controller["l_shoulder"])
    controller.press_shoulder(melee.Button.BUTTON_R,
                              sampled_controller["r_shoulder"])


def _navigate_menus(gamestate, controller, cpu_controller):
    """Menus: bot Fox (port 1) vs. level-9 CPU Fox, Yoshi's Story; CPU autostarts."""
    melee.MenuHelper.menu_helper_simple(gamestate,
                                        controller,
                                        melee.Character.FOX,
                                        melee.Stage.YOSHIS_STORY,
                                        connect_code=None,
                                        autostart=False,
                                        swag=False)
    melee.MenuHelper.menu_helper_simple(gamestate,
                                        cpu_controller,
                                        melee.Character.FOX,
                                        melee.Stage.YOSHIS_STORY,
                                        connect_code=None,
                                        cpu_level=9,
                                        autostart=True,
                                        swag=False)