Example #1
## Imports inferred from usage below; the remaining helpers (init_models,
## VAECGame, Controller, LSTMDataset, fetch_new_run, train_epoch,
## create_optimizer, create_state, save_checkpoint, create_results,
## init_controller, rankmin, sample_long_term, collate_fn) and the upper-case
## constants are assumed to come from the surrounding repository.
import os
import pickle
import time

import gridfs
import numpy as np
from pymongo import MongoClient
from torch.multiprocessing import Queue  # assumed; multiprocessing.Queue also fits
from torch.utils.data import DataLoader


def test_best_controller(current_time):
    """
    Evaluate the current best saved controller on a single game / level.
    """
    current_time = str(current_time)
    games = GAMES
    levels = LEVELS
    result_queue = Queue()

    vae, lstm, best_controller, solver, checkpoint = init_models(
        current_time, sequence=1, load_vae=True, load_lstm=True,
        load_controller=True)
    game = games[0]
    level = levels[game][1]
    print("[CONTROLLER] Current level is: %s" % level)
    new_game = VAECGame(0, vae, lstm, best_controller,
                        game, level, result_queue)
    new_game.start()
    new_game.join()


def train_lstm(current_time):
    """
    Train the LSTM to be able to predict the next latent vector z given the current vector z
    and an action.
    """

    dataset = LSTMDataset()
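    ## Handles to the MongoDB collection and GridFS store that
    ## fetch_new_run() below pulls new rollouts from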
    client = MongoClient()
    db = client.retro_contest
    collection = db[current_time]
    fs = gridfs.GridFS(db)

    last_id = 0
    lr = LR
    version = 1
    total_ite = 1

    ## Load or create models
    vae, lstm, _, _, checkpoint = init_models(current_time,
                                              load_vae=True,
                                              load_lstm=True,
                                              load_controller=False)
    if not checkpoint:
        optimizer = create_optimizer(lstm, lr)
        state = create_state(version, lr, total_ite, optimizer)
        save_checkpoint(lstm, "lstm", state, current_time)
    else:
        optimizer = create_optimizer(lstm, lr, param=checkpoint['optimizer'])
        total_ite = checkpoint['total_ite']
        lr = checkpoint['lr']
        version = checkpoint['version']

    ## Fill the dataset (or wait for the database to be filled);
    ## each stored run is assumed to hold PLAYOUTS playouts
    while len(dataset) * PLAYOUTS < SIZE:
        last_id = fetch_new_run(collection,
                                fs,
                                dataset,
                                last_id,
                                loaded_version=current_time)
        time.sleep(5)

    dataloader = DataLoader(dataset,
                            batch_size=BATCH_SIZE_LSTM,
                            collate_fn=collate_fn)
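
    ## Iterate over the dataset indefinitely; fetch_new_run() below appends
    ## freshly generated runs to it as they arrive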
    while True:
        running_loss = []
        batch_loss = []

        for batch_idx, (frames, actions) in enumerate(dataloader):

            ## Save the model
            if total_ite % SAVE_TICK == 0:
                version += 1
                state = create_state(version, lr, total_ite, optimizer)
                save_checkpoint(lstm, "lstm", state, current_time)

            ## Save a picture of the long term sampling
            if total_ite % SAVE_PIC_TICK == 0:
                sample_long_term(vae, lstm, frames, version, total_ite)

            ## Encode the raw frames into latent vectors; the LSTM is
            ## trained entirely in the VAE's latent space
            encoded = vae(frames, encode=True)
            example = {'encoded': encoded, 'actions': actions}

            loss = train_epoch(lstm, optimizer, example)
            running_loss.append(loss)

            ## Print the averaged running loss
            if total_ite % LOSS_TICK == 0:
                print("[TRAIN] current iteration: %d, averaged loss: %.3f"
                      % (total_ite, np.mean(running_loss)))
                batch_loss.append(np.mean(running_loss))
                running_loss = []

            ## Fetch new runs; if nothing new has arrived, loop back to
            ## the beginning of the collection
            if total_ite % REFRESH_TICK == 0:
                new_last_id = fetch_new_run(collection, fs, dataset, last_id)
                if new_last_id == last_id:
                    last_id = 0
                else:
                    last_id = new_last_id

            total_ite += 1

        if len(batch_loss) > 0:
            print("[TRAIN] Average backward pass loss : %.3f, current lr: %f" %
                  (np.mean(batch_loss), lr))
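

## A minimal sketch of what a single train_epoch() step could look like here,
## assuming example['encoded'] is a (seq_len, batch, LATENT_VEC) tensor of
## latents and example['actions'] a matching tensor of actions, with the LSTM
## returning a next-latent prediction per step. The repository's actual
## train_epoch is not shown on this page and may well use an MDN head with a
## mixture loss (as in the World Models paper); plain MSE is used here purely
## for illustration.
import torch
import torch.nn.functional as F


def train_epoch_sketch(lstm, optimizer, example):
    """Hypothetical single step: predict z_{t+1} from (z_t, a_t)."""
    z, actions = example['encoded'], example['actions']

    ## Inputs: the current latent concatenated with the action taken;
    ## targets: the same latents shifted forward by one timestep
    inputs = torch.cat((z[:-1], actions[:-1]), dim=-1)
    targets = z[1:].detach()

    predictions = lstm(inputs)
    loss = F.mse_loss(predictions, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()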


def train_controller(current_time):
    """
    Train the controllers by using the CMA-ES algorithm to improve candidature solutions
    by testing them in parallel using multiprocessing
    """

    current_time = str(current_time)
    number_generations = 1
    games = GAMES
    levels = LEVELS
    current_game = False
    result_queue = Queue()

    vae, lstm, best_controller, solver, checkpoint = init_models(
        current_time,
        sequence=1,
        load_vae=True,
        load_controller=True,
        load_lstm=True)
    if checkpoint:
        current_ctrl_version = checkpoint["version"]
        current_solver_version = checkpoint["solver_version"]
        ## solver.result() is assumed to return (best_params, best_score, ...)
        new_results = solver.result()
        current_best = new_results[1]
    else:
        current_ctrl_version = 1
        current_solver_version = 1
        current_best = 0

    while True:
        ## Draw a new population of candidate parameter vectors from the solver
        solutions = solver.ask()
        fitlist = np.zeros(POPULATION)
        eval_left = 0

        ## Once a level is beaten (or on the very first generation), draw a
        ## new level; move on to the next game when one runs out of levels
        if current_best > SCORE_CAP or not current_game:
            if not current_game or len(levels[current_game]) == 0:
                current_game = games[0]
                games.remove(current_game)
                current_best = 0
            current_level = np.random.choice(levels[current_game])
            levels[current_game].remove(current_level)

        print("[CONTROLLER] Current game: %s and level is: %s" %
              (current_game, current_level))
        while eval_left < POPULATION:
            jobs = []
            ## Evaluate at most PARALLEL candidates per batch
            todo = min(PARALLEL, POPULATION - eval_left)

            ## Create the child processes to evaluate in parallel
            print("[CONTROLLER] Starting new batch")
            for job in range(todo):
                process_id = eval_left + job

                ## Assign new weights to the controller, given by the CMA
                controller = Controller(LATENT_VEC, PARAMS_FC1,
                                        ACTION_SPACE).to(DEVICE)
                init_controller(controller, solutions[process_id])

                ## Start the evaluation
                new_game = VAECGame(process_id, vae, lstm, controller,
                                    current_game, current_level, result_queue)
                new_game.start()
                jobs.append(new_game)

            ## Wait for the evaluation to be completed
            for p in jobs:
                p.join()

            eval_left = eval_left + todo
            print("[CONTROLLER] Done with batch")

        ## Get the results back from the processes
        times = create_results(result_queue, fitlist)

        ## For display
        current_score = np.max(fitlist)
        average_score = np.mean(fitlist)

        ## Update the solver: keep the argmax of the raw scores, then
        ## rank-transform the scores (fitness shaping, see the sketch after
        ## this function) before handing them to the solver
        max_idx = np.argmax(fitlist)
        fitlist = rankmin(fitlist)
        solver.tell(fitlist)
        new_results = solver.result()

        ## Display
        print("[CONTROLLER] Total duration for generation: %.3f seconds, "
              "average duration: %.3f seconds per process, %.3f seconds per run"
              % (np.sum(times), np.mean(times), np.mean(times) / REPEAT_ROLLOUT))
        print("[CONTROLLER] Creating generation: {} ...".format(
            number_generations + 1))
        print("[CONTROLLER] Current best score: {}, new run best score: {}"
              .format(current_best, current_score))
        print("[CONTROLLER] Best score ever: {}, current number of improvements: {}"
              .format(current_best, current_ctrl_version))
        print("[CONTROLLER] Average score on all of the processes: {}\n"
              .format(average_score))

        ## Save the new best controller
        if current_score > current_best:
            init_controller(best_controller, solutions[max_idx])
            state = {
                'version': current_ctrl_version,
                'solver_version': current_solver_version,
                'score': current_score,
                'level': current_level,
                'game': current_game,
                'generation': number_generations
            }
            save_checkpoint(best_controller, "controller", state, current_time)
            current_ctrl_version += 1
            current_best = current_score

        ## Periodically save the solver and switch to a random level
        if number_generations % SAVE_SOLVER_TICK == 0:
            dir_path = os.path.join(
                os.path.dirname(os.path.realpath(__file__)), 'saved_models',
                current_time, "{}-solver.pkl".format(current_solver_version))
            with open(dir_path, 'wb') as f:
                pickle.dump(solver, f)
            current_solver_version += 1
            current_level = np.random.choice(levels[current_game])

        number_generations += 1
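

## The two helpers doing the interesting CMA-ES work above, rankmin() and
## init_controller(), are not shown on this page. Below is a minimal sketch of
## what they could look like, assuming an estool-style solver and a plain
## feed-forward controller; scipy.stats.rankdata and
## torch.nn.utils.vector_to_parameters are used here purely for illustration,
## and the repository's actual implementations may differ.
from scipy.stats import rankdata


def rankmin_sketch(fitlist):
    """Hypothetical fitness shaping: map raw scores to centered ranks.

    Rank transformation makes the CMA-ES update invariant to the scale of
    the raw game scores and robust to outlier runs.
    """
    ranks = rankdata(fitlist, method='min')         # 1 = worst raw score
    return (ranks - 1) / (len(fitlist) - 1) - 0.5   # values in [-0.5, 0.5]


def init_controller_sketch(controller, solution):
    """Hypothetical weight injection: copy a flat CMA-ES solution vector
    into the controller's parameters."""
    flat = torch.tensor(solution, dtype=torch.float32, device=DEVICE)
    torch.nn.utils.vector_to_parameters(flat, controller.parameters())

## With these in place, each generation above follows the standard CMA-ES
## pattern: solver.ask() draws POPULATION candidate vectors, every candidate
## is scored in a child process, and solver.tell() updates the search
## distribution from the shaped fitness values.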