Example #1
0
def fn_opt(trial: optuna.Trial, n_rounds: int = 10,
           timesteps_per_round: int = 200) -> float:
    """Optuna objective: train an FNSAC agent on FNGym with sampled hyperparameters.

    Samples the policy name and SAC-style hyperparameters from ``trial``,
    builds ``fn_sac.FNSAC`` on a fresh ``fn_gym.FNGym(0.2)`` environment, then
    trains for ``n_rounds`` rounds of ``timesteps_per_round`` timesteps each.
    After every round the environment's running reward is reported to the
    trial so the study's pruner can stop the trial early.

    Args:
        trial: Optuna trial used to sample hyperparameters and receive reports.
        n_rounds: Number of learn/report rounds (default 10).
        timesteps_per_round: ``total_timesteps`` passed to each
            ``model.learn`` call (default 200).

    Returns:
        ``env.get_running_reward()`` after the final round (presumably a
        float — TODO confirm against FNGym).

    Raises:
        optuna.TrialPruned: If ``trial.should_prune()`` is true after a round.
        ValueError: If training is interrupted with Ctrl+C; the original
            ``KeyboardInterrupt`` is chained as ``__cause__``.
    """
    try:
        net_arch = trial.suggest_categorical('net_arch',
                                             ['CnnPolicy', 'LnCnnPolicy'])
        gamma = trial.suggest_categorical(
            'gamma', [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999])
        learning_rate = trial.suggest_loguniform('lr', 1e-5, 1)
        batch_size = trial.suggest_categorical('batch_size',
                                               [16, 32, 64, 128, 256, 512])
        buffer_size = trial.suggest_categorical(
            'buffer_size', [int(1e4), int(1e5), int(1e6)])
        learning_starts = trial.suggest_categorical('learning_starts',
                                                    [0, 25, 50, 75, 100])
        gradient_steps = trial.suggest_categorical('gradient_steps',
                                                   [5, 15, 20, 50, 100, 300])
        ent_coef = trial.suggest_categorical(
            'ent_coef', ['auto', 0.5, 0.1, 0.05, 0.01, 0.0001])

        # target_entropy is only sampled when ent_coef == 'auto'; otherwise
        # FNSAC is always given 'auto'.
        target_entropy = 'auto'
        if ent_coef == 'auto':
            target_entropy = trial.suggest_categorical(
                'target_entropy', ['auto', -1, -10, -20, -50, -100])

        env = fn_gym.FNGym(0.2)
        model = fn_sac.FNSAC(net_arch,
                             env,
                             gamma=gamma,
                             learning_rate=learning_rate,
                             batch_size=batch_size,
                             buffer_size=buffer_size,
                             learning_starts=learning_starts,
                             gradient_steps=gradient_steps,
                             ent_coef=ent_coef,
                             target_entropy=target_entropy)

        for round_idx in range(n_rounds):
            model.learn(total_timesteps=timesteps_per_round)
            # Report against cumulative timesteps trained so far, so the
            # pruner compares trials at equal amounts of training.
            trial.report(env.get_running_reward(),
                         (round_idx + 1) * timesteps_per_round)

            if trial.should_prune():
                raise optuna.TrialPruned()
    except KeyboardInterrupt as err:
        # Pause so the operator can read the console, then surface the
        # interrupt as a ValueError (presumably so the study's ``catch=`` can
        # treat it as a failed trial — TODO confirm with the caller).
        # input() returns on Enter, not on an arbitrary key.
        input('Keyboard Interrupt. Press Enter to continue')
        raise ValueError("Exit Trial, Keyboard Interrupt") from err

    return env.get_running_reward()
import fnai.fn_gym as fn_gym
from stable_baselines import SAC 
from stable_baselines.common.policies import CnnLnLstmPolicy
from stable_baselines import PPO2
from os import path

import argparse

# Command line: an optional path to a previously saved model to resume from.
# When omitted, the placeholder 'none' is used and normally won't exist on disk.
parser = argparse.ArgumentParser(description='Train the model')
parser.add_argument(nargs='?', default='none', dest='model_path',
                    help='(optional) parameters to start training from')

args = parser.parse_args()

env = fn_gym.FNGym(0.2)
print(env.observation_space)
print(env.action_space)

if path.exists(args.model_path):
    print(f'Loading Model: {args.model_path}')
    # In stable-baselines, ``load`` is a classmethod that returns a NEW model;
    # calling it on an instance and ignoring the result silently discards the
    # loaded weights. Rebind ``model`` and attach the live environment.
    model = PPO2.load(args.model_path, env=env,
                      tensorboard_log='sac_fn_tensorboard')
else:
    print('Using new model')
    model = PPO2(CnnLnLstmPolicy, env, tensorboard_log='sac_fn_tensorboard',
                 nminibatches=1)

# Training-loop bookkeeping (consumed by the loop that follows this setup).
keep_training = True
log_num = 0
manual_log_num = 0
train_interval = 100000
while keep_training:
import fnai.fn_gym as fn_gym
import time
import numpy as np
from PIL import Image

# Smoke test: reset the environment, then grab 500 consecutive captures of the
# window region and print a single element of each, presumably as a quick
# manual check that screen capture is working.
test_gym = fn_gym.FNGym(0.2)
init_img = test_gym.reset()  # only referenced by the disabled snippet below
for _ in range(500):
    frame = test_gym.d3d_buff.screenshot(region=test_gym.win_coords)
    # Index [25, 10, 0] assumes a (row, col, channel) array — TODO confirm.
    print(frame[25, 10, 0])

# Disabled snippet, kept for reference: inspect observation shapes and save
# one stepped observation to disk.
# print(np.shape(init_img))
# obs, rew, done, info = test_gym.step([1.0, 1.0, 1.0])
# print(np.shape(obs))
# Image.fromarray(obs).save('test.png')
import stable_baselines.common.env_checker as env_checker
import fnai.fn_gym as fn_gym

# Run stable-baselines' environment checker on FNGym: problems are reported as
# warnings (warn=True) and the render() check is skipped.
# NOTE(review): the meaning of FNGym's second positional argument (True) is not
# visible here — confirm against fn_gym.
test_gym = fn_gym.FNGym(0.2, True)
env_checker.check_env(
    test_gym,
    warn=True,
    skip_render_check=True,
)