Example #1
def _get_experiences_sync(settings: TrainSettings, executable: str, port_chooser: typing.Callable,
                          create_flags: int, aggressive: bool, spec: bool,
                          replaypath: str, settings_path: str, tar_num_ticks: int):
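    """Writes the training-bot settings to settings_path, then repeatedly runs
    games (server, training bot, adversary bot, and optionally a spectator)
    until the replay buffer at replaypath holds at least tar_num_ticks
    experiences."""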
    session: SessionSettings = settings.current_session
    num_ticks_to_do = tar_num_ticks
    if os.path.exists(replaypath):
        replay = rb.FileReadableReplayBuffer(replaypath)
        num_ticks_to_do -= len(replay)
        replay.close()

    with open(settings_path, 'w') as outfile:
        json.dump({'teacher_force_amt': session.train_force_amount,
                   'replay_path': replaypath}, outfile)

    while num_ticks_to_do > 0:
        print(f'--starting game to get another {num_ticks_to_do} experiences--')
        sys.stdout.flush()
        secret1 = secrets.token_hex()
        secret2 = secrets.token_hex()
        port = port_chooser()
        procs = []
        procs.append(_start_server(executable, secret1, secret2, port, session.tie_len,
                                   aggressive, create_flags))
        if random.random() < 0.5:
            # randomly swap which secret goes to the training bot vs. the adversary
            secret1, secret2 = secret2, secret1

        time.sleep(2)

        procs.append(_start_bot(executable, settings.train_bot, secret1, port, create_flags,
                                aggressive, 'train_bot.log', ['--settings', settings_path]))
        procs.append(_start_bot(executable, settings.adver_bot, secret2, port, create_flags,
                                aggressive, 'adver_bot.log'))
        if spec:
            procs.append(_start_spec(executable, port, create_flags))

        for proc in procs:
            proc.wait()

        print('--finished game--')
        sys.stdout.flush()
        time.sleep(0.5)
        if not os.path.exists(replaypath):
            print('--game failed unexpectedly (no replay), waiting a bit and restarting--')
            sys.stdout.flush()
        else:
            replay = rb.FileReadableReplayBuffer(replaypath)
            num_ticks_to_do = tar_num_ticks - len(replay)
            replay.close()
            if num_ticks_to_do <= 0 and session.balance:
                rb.balance_experiences(
                    replaypath, **BALANCE_LOOKUP[session.balance_technique])
                replay = rb.FileReadableReplayBuffer(replaypath)
                num_ticks_to_do = tar_num_ticks - len(replay)
                replay.close()
        time.sleep(2)
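The port_chooser passed to _get_experiences_sync is only ever called with no
arguments (port = port_chooser()) to pick the port for the next game. A minimal
sketch of such a chooser, assuming a half-open [port_min, port_max) range like
the one _get_experiences_async carves out per worker (the name
make_port_chooser is illustrative, not from the original):

import itertools
import typing

def make_port_chooser(port_min: int, port_max: int) -> typing.Callable[[], int]:
    """Returns a callable that cycles through the ports in [port_min, port_max)."""
    ports = itertools.cycle(range(port_min, port_max))
    return lambda: next(ports)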
Example #2
def offline_learning():
    """Loads the replay buffer and trains on it."""
    perf_file = os.path.join(SAVEDIR, 'offline_learning_perf.log')
    perf = perf_stats.LoggingPerfStats('deep1 offline learning', perf_file)

    replay = replay_buffer.FileReadableReplayBuffer(REPLAY_FOLDER, perf=perf)
    try:
        print(f'loaded {len(replay)} experiences for replay...')
        if not os.path.exists(MODELFILE):
            _init_model()

        network = Deep1ModelTrain.load(MODELFILE)
        teacher = MyTeacher(FFTeacher())

        train_pwl = MyPWL(replay, Deep1ModelEval.load(EVAL_MODELFILE), teacher)
        test_pwl = train_pwl

        def update_target(ctx: tnr.GenericTrainingContext, hint: str):
            ctx.logger.info('swapping target network, hint=%s', hint)
            network.save(MODELFILE, exist_ok=True)

            new_target = Deep1ModelToEval(network.fc_layers)
            for _ in range(3):
                train_pwl.mark()
                for _ in range(0, 1024, ctx.batch_size):
                    train_pwl.fill(ctx.points, ctx.labels)
                    teacher.classify_many(new_target, ctx.points,
                                          ctx.labels.unsqueeze(1))
                new_target.learning_to_current()
                train_pwl.reset()

            new_target = new_target.to_evaluative()
            new_target.save(EVAL_MODELFILE, exist_ok=True)

            train_pwl.target_model = new_target

        trainer = tnr.GenericTrainer(
            train_pwl=train_pwl,
            test_pwl=test_pwl,
            teacher=teacher,
            batch_size=32,
            learning_rate=0.0001,
            optimizer=torch.optim.Adam(
                [p for p in network.parameters() if p.requires_grad],
                lr=0.0001),
            criterion=torch.nn.MSELoss())
        (trainer
         .reg(tnr.EpochsTracker())
         .reg(tnr.EpochsStopper(100))
         .reg(tnr.InfOrNANDetecter())
         .reg(tnr.InfOrNANStopper())
         .reg(tnr.DecayTracker())
         .reg(tnr.DecayStopper(1))
         .reg(tnr.OnEpochCaller.create_every(update_target, skip=CUTOFF))
         # smaller cutoffs require more bootstrapping
         .reg(tnr.DecayOnPlateau()))
        res = trainer.train(network,
                            target_dtype=torch.float32,
                            point_dtype=torch.float32,
                            perf=perf)
        if res['inf_or_nan']:
            print('training failed! inf or nan!')
    finally:
        replay.close()
Example #3
def _get_experiences_async(settings: TrainSettings, executable: str, port_min: int, port_max: int,
                           create_flags: int, aggressive: bool, spec: bool, nthreads: int):
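    """Splits the remaining experience collection across nthreads worker
    processes, giving each its own port range, replay folder, and settings
    file, then merges the per-worker replays (plus any existing and holdover
    replays) into settings.replay_folder."""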
    num_ticks_to_do = settings.current_session.tar_ticks
    if os.path.exists(settings.replay_folder):
        replay = rb.FileReadableReplayBuffer(settings.replay_folder)
        num_ticks_to_do -= len(replay)
        replay.close()

        if num_ticks_to_do <= 0:
            print(f'get_experiences_async nothing to do (already at {settings.replay_folder})')
            return

    replay_paths = [os.path.join(settings.bot_folder, f'replay_{i}') for i in range(nthreads)]
    setting_paths = [os.path.join(settings.bot_folder, f'settings_{i}.json')
                     for i in range(nthreads)]
    workers = []
    serd_settings = ser.serialize_embeddable(settings)
    ports_per = (port_max - port_min) // nthreads
    if ports_per < 3:
        raise ValueError('not enough ports assigned '
                         + f'({nthreads} threads, {port_max-port_min} ports)')
    ticks_per = int(math.ceil(num_ticks_to_do / nthreads))
    for worker in range(nthreads):
        proc = Process(target=_get_experiences_target,
                       args=(serd_settings, executable, port_min + worker*ports_per,
                             port_min + (worker+1)*ports_per, create_flags, aggressive, spec,
                             replay_paths[worker], setting_paths[worker], ticks_per))
        proc.start()
        workers.append(proc)
        time.sleep(1)

    for proc in workers:
        proc.join()

    print(f'get_experiences_async finished, storing in {settings.replay_folder}')
    if os.path.exists(settings.replay_folder):
        # preserve the existing experiences (already counted against
        # num_ticks_to_do above) by renaming the folder and merging it back in
        tmp_replay_folder = settings.replay_folder + '_tmp'
        os.rename(settings.replay_folder, tmp_replay_folder)
        replay_paths.append(tmp_replay_folder)

    if os.path.exists(HOLDOVER_DIR):
        replay_paths.append(HOLDOVER_DIR)

    rb.merge_buffers(replay_paths, settings.replay_folder)

    for path in replay_paths:
        filetools.deldir(path)
Example #4
def _run(args):
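    """Collects experiences for the deep1 bot asynchronously, then plots the
    top 2 principal components of the evaluation network over that replay."""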
    executable = 'python3' if args.py3 else 'python'
    port = 1769
    nthreads = args.numthreads

    settings = deep1_runner.TrainSettings(
        train_bot='or_reinforce.deep.deep1.deep1',
        adver_bot='optimax_rogue_bots.randombot.RandomBot',
        bot_folder=os.path.join('out', 'or_reinforce', 'deep', 'deep1'),
        train_seq=[
            deep1_runner.SessionSettings(
                tie_len=111,
                tar_ticks=2000,
                train_force_amount=args.train_force_amount)
        ],
        cur_ind=0)
    deep1_runner._get_experiences_async(  # pylint: disable=protected-access
        settings, executable, port, port + nthreads * 10, 0, False, False,
        nthreads)

    replay = replay_buffer.FileReadableReplayBuffer(deep1.REPLAY_FOLDER)
    try:
        print(f'loaded {len(replay)} experiences for analysis...')

        network = deep1.Deep1ModelEval.load(deep1.EVAL_MODELFILE)
        teacher = deep1.MyTeacher(FFTeacher())

        pwl = deep1.MyPWL(replay,
                          deep1.Deep1ModelEval.load(deep1.EVAL_MODELFILE),
                          teacher)

        print('--fetching top 2 pcs--')
        traj: pca_gen.PCTrajectoryGen = pca_gen.find_trajectory(
            network, pwl, 2)
        print('--plotting top 2 pcs--')
        pca_gen.plot_trajectory(traj,
                                os.path.join(SAVEDIR, 'pca'),
                                exist_ok=True,
                                transparent=False,
                                compress=False,
                                s=16)
        print('--finished--')
    finally:
        replay.close()
Example #5
def get_unique_states_with_exps(
        replay_path: str) -> typing.Tuple[
            torch.Tensor, typing.List[replay_buffer.Experience]]:
    """Gets the unique states and a corresponding representative experience
    for each state."""
    result = []
    result_exps = []

    buffer = replay_buffer.FileReadableReplayBuffer(replay_path)
    try:
        for _ in range(len(buffer)):
            exp: replay_buffer.Experience = next(buffer)
            as_torch = torch.from_numpy(exp.encoded_state)
            if all((existing != as_torch).sum() > 0 for existing in result):
                result.append(as_torch)
                result_exps.append(exp)
    finally:
        buffer.close()

    return torch.cat(tuple(i.unsqueeze(0) for i in result), dim=0), result_exps
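A minimal usage sketch for the helper above (the replay path is illustrative,
not taken from the original):

import os

states, exps = get_unique_states_with_exps(os.path.join('out', 'replay'))
print(f'found {states.shape[0]} unique encoded states '
      f'across {len(exps)} representative experiences')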
Example #6
def get_unique_states(replay_path: str) -> torch.Tensor:
    """Gets the unique encoded states that are in the given replay folder.

    Arguments:
        replay_path (str): the path to where the replay experiences are stored

    Returns:
        unique_states (torch.Tensor): the unique encoded game states within
            the experiences
    """
    result = []
    buffer = replay_buffer.FileReadableReplayBuffer(replay_path)
    try:
        for _ in range(len(buffer)):
            exp: replay_buffer.Experience = next(buffer)
            as_torch = torch.from_numpy(exp.encoded_state)
            if all((existing != as_torch).sum() > 0 for existing in result):
                result.append(as_torch)
    finally:
        buffer.close()

    return torch.cat(tuple(i.unsqueeze(0) for i in result), dim=0)
Example #7
def main():
    """Main entry for tests"""

    if os.path.exists(FILEPATH):
        filetools.deldir(FILEPATH)

    buf = rb.FileWritableReplayBuffer(os.path.join(FILEPATH, '1'), exist_ok=False)

    sbuf = []

    for _ in range(5):
        exp = make_exp()
        buf.add(exp)
        sbuf.append(exp)

    buf2 = rb.FileWritableReplayBuffer(os.path.join(FILEPATH, '2'), exist_ok=False)

    for _ in range(5):
        exp = make_exp()
        buf2.add(exp)
        sbuf.append(exp)

    buf.close()
    buf2.close()

    rb.merge_buffers([os.path.join(FILEPATH, '2'), os.path.join(FILEPATH, '1')],
                     os.path.join(FILEPATH, '3'))

    buf = rb.FileReadableReplayBuffer(os.path.join(FILEPATH, '3'))

    for _ in range(3):
        missing = list(sbuf)
        for _ in range(10):
            got = buf.sample(1)[0]
            for i in range(len(missing)): #pylint: disable=consider-using-enumerate
                if got == missing[i]:
                    missing.pop(i)
                    break
            else:
                raise ValueError(f'got bad value: {got} expected one of \n'
                                 + '\n'.join(repr(exp) for exp in missing))

    buf.mark()
    got = buf.sample(1)[0]
    buf.reset()
    got2 = buf.sample(1)[0]
    if got != got2:
        raise ValueError(f'mark did not retrieve same experience: {got} vs {got2}')

    buf.close()

    buf = rb.MemoryPrioritizedReplayBuffer(os.path.join(FILEPATH, '3'))

    saw = []
    buf.mark()
    for _ in range(15):
        got = buf.sample(1)[0]
        saw.append(got)
    buf.reset()
    for _ in range(15):
        got = buf.sample(1)[0]
        if got != saw[0]:
            raise ValueError(f'got bad value: {got}, expected {saw[0]}')
        saw.pop(0)

    for _ in range(15):
        got = buf.pop()[2]
        found = False
        for exp in sbuf:
            if got == exp:
                found = True
                got.last_td_error = random.random()
                exp.last_td_error = got.last_td_error
                buf.add(got)
                break
        if not found:
            raise ValueError(f'got {got}, expected one of '
                             + '\n'.join(repr(exp) for exp in sbuf))

    buf.close()
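If this test module is meant to be run directly, a standard guard would invoke
main() (an assumption; the original entry point is not shown here):

if __name__ == '__main__':
    main()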