# Example #1
# 0
def launch(args, defaults, description):
    """
    Execute a complete training run.

    Parses the command line, seeds the RNG, configures the ALE emulator,
    builds (or loads from disk) the QEC table and episodic-control agent,
    then runs the full experiment.

    :param args: raw command-line argument list.
    :param defaults: object holding default constants (BASE_ROM_PATH,
        RESIZED_WIDTH, RESIZED_HEIGHT, ...).
    :param description: help text forwarded to the argument parser.
    """

    logging.basicConfig(level=logging.INFO)
    parameters = process_args(args, defaults, description)

    # Normalize the ROM name so it always carries the '.bin' suffix.
    if parameters.rom.endswith('.bin'):
        rom = parameters.rom
    else:
        rom = "%s.bin" % parameters.rom
    full_rom_path = os.path.join(defaults.BASE_ROM_PATH, rom)

    # Fixed seed for reproducible runs; otherwise OS-entropy seeding.
    if parameters.deterministic:
        rng = np.random.RandomState(123456)
    else:
        rng = np.random.RandomState()

    ale = ale_python_interface.ALEInterface()
    ale.setInt('random_seed', rng.randint(1000))

    # FOR VISUALIZATION
    USE_SDL = False
    if parameters.display_screen:
        if USE_SDL:
            import sys
            if sys.platform == 'darwin':
                import pygame
                pygame.init()
                ale.setBool('sound', False)  # Sound doesn't work on OSX

    ale.setBool('display_screen', parameters.display_screen)
    ale.setFloat('repeat_action_probability',
                 parameters.repeat_action_probability)

    ale.loadROM(full_rom_path)

    num_actions = len(ale.getMinimalActionSet())

    agent = None

    if parameters.use_episodic_control:
        if parameters.qec_table is None:
            qec_table = EC_functions.QECTable(
                parameters.knn, parameters.state_dimension,
                parameters.projection_type,
                defaults.RESIZED_WIDTH * defaults.RESIZED_HEIGHT,
                parameters.buffer_size, num_actions, rng)
        else:
            # BUG FIX: pickles are binary data — open in 'rb' (text mode
            # can corrupt the stream) and close the handle deterministically
            # via a context manager.
            with open(parameters.qec_table, 'rb') as handle:
                qec_table = cPickle.load(handle)

        agent = IBL_agent.EpisodicControl(qec_table, parameters.ec_discount,
                                          num_actions,
                                          parameters.epsilon_start,
                                          parameters.epsilon_min,
                                          parameters.epsilon_decay,
                                          parameters.experiment_prefix, rng)

    experiment = ale_experiment.ALEExperiment(
        ale, agent, defaults.RESIZED_WIDTH, defaults.RESIZED_HEIGHT,
        parameters.resize_method, parameters.epochs,
        parameters.steps_per_epoch, parameters.steps_per_test,
        parameters.frame_skip, parameters.death_ends_episode,
        parameters.max_start_nullops, rng)

    experiment.run()
# Example #2
# 0
def launch(args, defaults, description):
    """
    Execute a complete training run.

    Parses the command line, seeds the RNG, configures the ALE emulator,
    builds (or loads from disk) the QEC table and episodic-control agent,
    then runs the full experiment.

    :param args: raw command-line argument list.
    :param defaults: object holding default constants (BASE_ROM_PATH,
        RESIZED_WIDTH, RESIZED_HEIGHT, ...).
    :param description: help text forwarded to the argument parser.
    """

    logging.basicConfig(level=logging.INFO)
    parameters = process_args(args, defaults, description)

    # Normalize the ROM name so it always carries the '.bin' suffix.
    if parameters.rom.endswith('.bin'):
        rom = parameters.rom
    else:
        rom = "%s.bin" % parameters.rom
    full_rom_path = os.path.join(defaults.BASE_ROM_PATH, rom)

    # Fixed seed for reproducible runs; otherwise OS-entropy seeding.
    if parameters.deterministic:
        rng = np.random.RandomState(123456)
    else:
        rng = np.random.RandomState()

    ale = ale_python_interface.ALEInterface()
    ale.setInt('random_seed', rng.randint(1000))

    if parameters.display_screen:
        import sys
        if sys.platform == 'darwin':
            import pygame
            pygame.init()
            ale.setBool('sound', False)  # Sound doesn't work on OSX

    ale.setBool('display_screen', parameters.display_screen)
    ale.setFloat('repeat_action_probability',
                 parameters.repeat_action_probability)

    ale.loadROM(full_rom_path)

    num_actions = len(ale.getMinimalActionSet())

    agent = None

    if parameters.use_episodic_control:
        if parameters.qec_table is None:
            qec_table = EC_functions.QECTable(
                parameters.knn, parameters.state_dimension,
                parameters.projection_type,
                defaults.RESIZED_WIDTH * defaults.RESIZED_HEIGHT,
                parameters.buffer_size, num_actions, rng)
        else:
            # BUG FIX: pickles are binary data — open in 'rb' (text mode
            # can corrupt the stream) and close the handle deterministically
            # via a context manager.
            with open(parameters.qec_table, 'rb') as handle:
                qec_table = pickle.load(handle)

        agent = EC_agent.EpisodicControl(qec_table, parameters.ec_discount,
                                         num_actions, parameters.epsilon_start,
                                         parameters.epsilon_min,
                                         parameters.epsilon_decay,
                                         parameters.experiment_prefix, rng)

    experiment = ale_experiment.ALEExperiment(
        ale, agent, defaults.RESIZED_WIDTH, defaults.RESIZED_HEIGHT,
        parameters.resize_method, parameters.epochs,
        parameters.steps_per_epoch, parameters.steps_per_test,
        parameters.frame_skip, parameters.death_ends_episode,
        parameters.max_start_nullops, rng)

    experiment.run()
def launch(args, defaults, description):
    """
    Execute a complete training run.

    Parses the command line, seeds the RNG, configures the ALE emulator,
    then — depending on ``parameters.method`` — builds (or loads from
    disk) the appropriate network / QEC table / agent combination and
    runs the full experiment.

    :param args: raw command-line argument list.
    :param defaults: object holding default constants (BASE_ROM_PATH,
        RESIZED_WIDTH, RESIZED_HEIGHT, ...).
    :param description: help text forwarded to the argument parser.
    """

    logging.basicConfig(level=logging.INFO)
    parameters = process_args(args, defaults, description)

    # Normalize the ROM name so it always carries the '.bin' suffix.
    if parameters.rom.endswith('.bin'):
        rom = parameters.rom
    else:
        rom = "%s.bin" % parameters.rom
    full_rom_path = os.path.join(defaults.BASE_ROM_PATH, rom)

    # Fixed seed for reproducible runs; otherwise OS-entropy seeding.
    if parameters.deterministic:
        rng = np.random.RandomState(123456)
    else:
        rng = np.random.RandomState()

    # Force a deterministic cuDNN backward-convolution algorithm so that
    # GPU runs are bitwise reproducible.
    if parameters.cudnn_deterministic:
        theano.config.dnn.conv.algo_bwd = 'deterministic'

    ale = ale_python_interface.ALEInterface()
    ale.setInt('random_seed', rng.randint(1000))

    if parameters.display_screen:
        import sys
        if sys.platform == 'darwin':
            import pygame
            pygame.init()
            ale.setBool('sound', False)  # Sound doesn't work on OSX

    ale.setBool('display_screen', parameters.display_screen)
    ale.setFloat('repeat_action_probability',
                 parameters.repeat_action_probability)

    ale.loadROM(full_rom_path)

    num_actions = len(ale.getMinimalActionSet())

    agent = None

    if parameters.method == 'ec_dqn':
        if parameters.nn_file is None:
            network = q_network.DeepQLearner(defaults.RESIZED_WIDTH,
                                             defaults.RESIZED_HEIGHT,
                                             num_actions,
                                             parameters.phi_length,
                                             parameters.discount,
                                             parameters.learning_rate,
                                             parameters.rms_decay,
                                             parameters.rms_epsilon,
                                             parameters.momentum,
                                             parameters.clip_delta,
                                             parameters.freeze_interval,
                                             parameters.batch_size,
                                             parameters.network_type,
                                             parameters.update_rule,
                                             parameters.batch_accumulator,
                                             rng, use_ec=True, double=parameters.double_dqn)
        else:
            # BUG FIX: pickles are binary — open in 'rb' and close the
            # handle via a context manager (was text-mode and leaked).
            with open(parameters.nn_file, 'rb') as handle:
                network = cPickle.load(handle)

        if parameters.qec_table is None:
            qec_table = EC_functions.QECTable(parameters.knn,
                                              parameters.state_dimension,
                                              parameters.projection_type,
                                              defaults.RESIZED_WIDTH*defaults.RESIZED_HEIGHT,
                                              parameters.buffer_size,
                                              num_actions,
                                              rng,
                                              parameters.rebuild_knn_frequency)
        else:
            with open(parameters.qec_table, 'rb') as handle:
                qec_table = cPickle.load(handle)

        agent = ale_agents.EC_DQN(network,
                                  qec_table,
                                  parameters.epsilon_start,
                                  parameters.epsilon_min,
                                  parameters.epsilon_decay,
                                  parameters.replay_memory_size,
                                  parameters.experiment_prefix,
                                  parameters.replay_start_size,
                                  parameters.update_frequency,
                                  parameters.ec_discount,
                                  num_actions,
                                  parameters.ec_testing,
                                  rng)

    if parameters.method == 'dqn_episodic_memory1':
        if parameters.nn_file is None:
            network = q_network.DeepQLearner(defaults.RESIZED_WIDTH,
                                             defaults.RESIZED_HEIGHT,
                                             num_actions,
                                             parameters.phi_length,
                                             parameters.discount,
                                             parameters.learning_rate,
                                             parameters.rms_decay,
                                             parameters.rms_epsilon,
                                             parameters.momentum,
                                             parameters.clip_delta,
                                             parameters.freeze_interval,
                                             parameters.batch_size,
                                             parameters.network_type,
                                             parameters.update_rule,
                                             parameters.batch_accumulator,
                                             rng, use_episodic_mem=True, double=parameters.double_dqn)
        else:
            with open(parameters.nn_file, 'rb') as handle:
                network = cPickle.load(handle)

        if parameters.qec_table is None:
            qec_table = EC_functions.QECTable(parameters.knn,
                                              parameters.state_dimension,
                                              parameters.projection_type,
                                              defaults.RESIZED_WIDTH*defaults.RESIZED_HEIGHT,
                                              parameters.buffer_size,
                                              num_actions,
                                              rng,
                                              parameters.rebuild_knn_frequency)
        else:
            with open(parameters.qec_table, 'rb') as handle:
                qec_table = cPickle.load(handle)

        agent = ale_agents.NeuralNetworkEpisodicMemory1(network,
                                                        qec_table,
                                                        parameters.epsilon_start,
                                                        parameters.epsilon_min,
                                                        parameters.epsilon_decay,
                                                        parameters.replay_memory_size,
                                                        parameters.experiment_prefix,
                                                        parameters.replay_start_size,
                                                        parameters.update_frequency,
                                                        parameters.ec_discount,
                                                        num_actions,
                                                        parameters.ec_testing,
                                                        rng)

    if parameters.method == 'dqn_episodic_memory2':
        # NOTE(review): this branch builds a network and a QEC table but
        # never assigns `agent`, so the experiment would run with
        # agent=None — looks like a missing
        # ale_agents.NeuralNetworkEpisodicMemory2(...) call; confirm
        # intended behavior before relying on this method.
        if parameters.nn_file is None:
            network = q_network.DeepQLearner(defaults.RESIZED_WIDTH,
                                             defaults.RESIZED_HEIGHT,
                                             num_actions,
                                             parameters.phi_length,
                                             parameters.discount,
                                             parameters.learning_rate,
                                             parameters.rms_decay,
                                             parameters.rms_epsilon,
                                             parameters.momentum,
                                             parameters.clip_delta,
                                             parameters.freeze_interval,
                                             parameters.batch_size,
                                             parameters.network_type,
                                             parameters.update_rule,
                                             parameters.batch_accumulator,
                                             rng, use_episodic_mem=True, double=parameters.double_dqn)
        else:
            with open(parameters.nn_file, 'rb') as handle:
                network = cPickle.load(handle)

        if parameters.qec_table is None:
            qec_table = EC_functions.QECTable(parameters.knn,
                                              parameters.state_dimension,
                                              parameters.projection_type,
                                              defaults.RESIZED_WIDTH*defaults.RESIZED_HEIGHT,
                                              parameters.buffer_size,
                                              num_actions,
                                              rng,
                                              parameters.rebuild_knn_frequency)
        else:
            with open(parameters.qec_table, 'rb') as handle:
                qec_table = cPickle.load(handle)

    if parameters.method == 'dqn_episodic_memory3':
        if parameters.nn_file is None:
            network = q_network.DeepQLearner(defaults.RESIZED_WIDTH,
                                             defaults.RESIZED_HEIGHT,
                                             num_actions,
                                             parameters.phi_length,
                                             parameters.discount,
                                             parameters.learning_rate,
                                             parameters.rms_decay,
                                             parameters.rms_epsilon,
                                             parameters.momentum,
                                             parameters.clip_delta,
                                             parameters.freeze_interval,
                                             parameters.batch_size,
                                             parameters.network_type,
                                             parameters.update_rule,
                                             parameters.batch_accumulator,
                                             rng, use_episodic_mem=True, double=parameters.double_dqn)
        else:
            with open(parameters.nn_file, 'rb') as handle:
                network = cPickle.load(handle)

        if parameters.qec_table is None:
            # Memory3 uses an LSH-hash table rather than a k-NN QEC table.
            qec_table = EC_functions.LshHash(parameters.state_dimension,
                                             defaults.RESIZED_WIDTH*defaults.RESIZED_HEIGHT,
                                             parameters.buffer_size,
                                             rng)
        else:
            with open(parameters.qec_table, 'rb') as handle:
                qec_table = cPickle.load(handle)

        agent = ale_agents.NeuralNetworkEpisodicMemory3(network,
                                                        qec_table,
                                                        parameters.epsilon_start,
                                                        parameters.epsilon_min,
                                                        parameters.epsilon_decay,
                                                        parameters.replay_memory_size,
                                                        parameters.experiment_prefix,
                                                        parameters.replay_start_size,
                                                        parameters.update_frequency,
                                                        parameters.ec_discount,
                                                        num_actions,
                                                        parameters.ec_testing,
                                                        rng)

    if parameters.method == 'dqn':
        if parameters.nn_file is None:
            network = q_network.DeepQLearner(defaults.RESIZED_WIDTH,
                                             defaults.RESIZED_HEIGHT,
                                             num_actions,
                                             parameters.phi_length,
                                             parameters.discount,
                                             parameters.learning_rate,
                                             parameters.rms_decay,
                                             parameters.rms_epsilon,
                                             parameters.momentum,
                                             parameters.clip_delta,
                                             parameters.freeze_interval,
                                             parameters.batch_size,
                                             parameters.network_type,
                                             parameters.update_rule,
                                             parameters.batch_accumulator,
                                             rng, double=parameters.double_dqn)
        else:
            with open(parameters.nn_file, 'rb') as handle:
                network = cPickle.load(handle)

        agent = ale_agents.NeuralAgent(network,
                                       parameters.epsilon_start,
                                       parameters.epsilon_min,
                                       parameters.epsilon_decay,
                                       parameters.replay_memory_size,
                                       parameters.experiment_prefix,
                                       parameters.replay_start_size,
                                       parameters.update_frequency,
                                       rng)

    if parameters.method == 'episodic_control':
        if parameters.qec_table is None:
            qec_table = EC_functions.QECTable(parameters.knn,
                                              parameters.state_dimension,
                                              parameters.projection_type,
                                              defaults.RESIZED_WIDTH*defaults.RESIZED_HEIGHT,
                                              parameters.buffer_size,
                                              num_actions,
                                              rng,
                                              parameters.rebuild_knn_frequency)
        else:
            with open(parameters.qec_table, 'rb') as handle:
                qec_table = cPickle.load(handle)

        agent = ale_agents.EpisodicControl(qec_table,
                                           parameters.ec_discount,
                                           num_actions,
                                           parameters.epsilon_start,
                                           parameters.epsilon_min,
                                           parameters.epsilon_decay,
                                           parameters.experiment_prefix,
                                           parameters.ec_testing,
                                           rng)

    experiment = ale_experiment.ALEExperiment(ale, agent,
                                              defaults.RESIZED_WIDTH,
                                              defaults.RESIZED_HEIGHT,
                                              parameters.resize_method,
                                              parameters.epochs,
                                              parameters.steps_per_epoch,
                                              parameters.steps_per_test,
                                              parameters.frame_skip,
                                              parameters.death_ends_episode,
                                              parameters.max_start_nullops,
                                              rng)

    experiment.run()