Example #1
def load_model(env, algorithm, filename):
    if algorithm == "ddpg":
        return DDPG.load(filename, env=env)
    elif algorithm == "td3":
        return TD3.load(filename, env=env)
    elif algorithm == "sac":
        return SAC.load(filename, env=env)
    else:
        raise Exception("--> Alican's LOG: Unknown agent type!")
Example #2
def test(MODEL_TEST):
    log_dir = "model_save/" + MODEL_PATH + "/" + MODEL_PATH + MODEL_TEST

    env = ENV(util='test', par=PARAM, dt=DT)
    env.render = True
    env = Monitor(env, log_dir)

    if PARAM['algo']=='td3':
        model = TD3.load(log_dir)
    elif PARAM['algo']=='ddpg':
        model = DDPG.load(log_dir)
    elif PARAM['algo']=='ppo':
        model = PPO.load(log_dir)

    # plot_results(f"model_save/")
    trade_dt = pd.DataFrame([])     # trade_dt: per-step trading data for all stocks
    result_dt = pd.DataFrame([])    # result_dt: one-year test results for all stocks
    for i in range(TEST_STOCK_NUM):
        state = env.reset()
        stock_bh_id = 'stock_bh_'+str(i)            # buy-and-hold series for this stock
        stock_port_id = 'stock_port_'+str(i)        # portfolio value series for this stock
        stock_action_id = 'stock_action_' + str(i)  # action series for this stock
        flow_L_id = 'stock_flow_' + str(i)          # cash-flow record for this stock
        stock_bh_dt, stock_port_dt, action_policy_dt, flow_L_dt = [], [], [], []
        day = 0
        while True:
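            # run one full test episode; model.predict returns (action, state)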
            action = model.predict(state)
            next_state, reward, done, info = env.step(action[0])
            state = next_state
            # print("trying:",day,"reward:", reward,"now profit:",env.profit)   # 测试每一步的交易policy
            stock_bh_dt.append(env.buy_hold)
            stock_port_dt.append(env.Portfolio_unit)
            action_policy_dt.append(action[0][0])  # record the policy's action
            flow_L_dt.append(env.flow)
            day+=1
            if done:
                print('stock: {}, total profit: {:.2f}%, buy hold: {:.2f}%, sp: {:.4f}, mdd: {:.2f}%, romad: {:.4f}'
                      .format(i, env.profit*100, env.buy_hold*100, env.sp, env.mdd*100, env.romad))
                # after the episode, record: stock ID, profit (%), buy_hold (%), Sharpe ratio, max drawdown (%), romad
                result=pd.DataFrame([[i,env.profit*100,env.buy_hold*100,env.sp,env.mdd*100,env.romad]])
                break

        trade_dt_stock = pd.DataFrame({stock_port_id: stock_port_dt,
                                       stock_bh_id: stock_bh_dt,
                                       stock_action_id: action_policy_dt,
                                       flow_L_id: flow_L_dt})  # trading data for this single stock

        trade_dt = pd.concat([trade_dt, trade_dt_stock], axis=1)    # merge each stock's trading data as new columns
        result_dt = pd.concat([result_dt,result],axis=0)            # append each stock's result row

    result_dt.columns = ['stock_id','profit(100%)','buy_hold(100%)','sp','mdd(100%)','romad']
    trade_dt.to_csv('out_dt/trade_'+MODEL_PATH+'.csv',index=False)
    result_dt.to_csv('out_dt/result_'+MODEL_PATH+'.csv',index=False)
Example #3
def main():
    args = parse_arguments()
    load_path = os.path.join("logs", args.env, args.agent, "best_model.zip")
    stats_path = os.path.join(args.log_dir, args.env, args.agent, "vec_normalize.pkl")

    if args.agent == 'ddpg':
        from stable_baselines3 import DDPG
        model = DDPG.load(load_path)
    elif args.agent == 'td3':
        from stable_baselines3 import TD3
        model = TD3.load(load_path)
    elif args.agent == 'ppo':
        from stable_baselines3 import PPO
        model = PPO.load(load_path)

    env = make_vec_env(args.env, n_envs=1)
    env = VecNormalize.load(stats_path, env)
    # do not update the normalization statistics at test time
    env.training = False
    # reward normalization is not needed at test time
    env.norm_reward = False
    
    # env = gym.make(args.env)
    img = []
    if args.render:
        env.render('human')
    done = False
    obs = env.reset()
    action = model.predict(obs)
    if args.gif:
        img.append(env.render('rgb_array'))

    if args.timesteps is None:
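        # no step limit given: roll out until the first episode terminates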
        while not done: 
            action, _= model.predict(obs)
            obs, reward, done, info = env.step(action)
            if args.gif:
                img.append(env.render('rgb_array'))
            else:
                env.render()
    else:
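        # fixed budget: roll out for the requested number of timesteps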
        for i in range(args.timesteps): 
            action, _= model.predict(obs)
            obs, reward, done, info = env.step(action)
            if args.gif:
                img.append(env.render('rgb_array'))
            else:
                env.render()

    if args.gif:
        gif_path = os.path.join("logs", args.env, args.agent, "recording.gif")
        imageio.mimsave(gif_path,
                        [np.array(frame) for i, frame in enumerate(img) if i % 2 == 0],
                        fps=29)
Example #4
def play():
    model = TD3.load("models/kuka_iiwa_insertion-v0")

    env = gym.make('kuka_iiwa_insertion-v0', use_gui=True)

    obs = env.reset()
    i = 0
    while True:
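        # run the trained policy indefinitely, printing every 100 steps and on episode end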
        i += 1
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        if i % 100 == 0 or dones: 
            print(obs, rewards, dones, info)
        if dones:
            print("="*20 + " RESET " + "="*20)
            env.reset()
Example #5
def test_td3():
    log_dir = f"model_save/best_model_td3_cnn"
    env = ENV(istest=True)
    env.render = True
    env = Monitor(env, log_dir)
    model = TD3.load(log_dir)
    plot_results(f"model_save/")
    for i in range(10):
        state = env.reset()
        while True:
            action = model.predict(state)
            next_state, reward, done, info = env.step(action[0])
            state = next_state
            # print("trying:",i,"action:", action,"now profit:",env.profit)
            if done:
                print('stock',i,' total profit=',env.profit,' buy hold=',env.buy_hold)
                break
Example #6
    def prepare_stage(self):

        dir = f'experiments/{self.config.experiment_name}'
        if not os.path.exists(dir):
            os.mkdir(dir)

        else:

            # recover the latest non-corrupted checkpoint, if one exists

            checkpoints = []
            for file in glob.glob(f'{dir}/status_checkpoint*'):
                checkpoints.append(
                    int(file.split('/status_checkpoint_')[1].split('.')[0]))
            checkpoints.sort()

            attempts = len(checkpoints) - 1

            while attempts >= 0:
                try:
                    f = open(
                        f'{dir}/status_checkpoint_{checkpoints[attempts]}.pkl',
                        'rb')
                    self.results_episodes, self.results_episodes_validation, self.current_checkpoint, self.current_episode = pickle.load(
                        f)

                    # only recovers pickle if model also available
                    env2 = DummyVecEnv([lambda: self.env])
                    self.model = TD3.load(
                        f'{dir}/model_checkpoint_{checkpoints[attempts]}',
                        env=env2)

                    self.log.write(
                        f'RECOVERED checkpoint {checkpoints[attempts]}')

                    # stop retrying once a checkpoint has been recovered
                    attempts = -1

                except Exception:
                    self.log.write(
                        f'ERROR: Could not recover checkpoint {checkpoints[attempts]}  {traceback.format_exc()}'
                    )
                    self.results_episodes, self.results_episodes_validation, self.current_checkpoint, self.current_episode = [], [], 0, 0

                attempts -= 1
Example #7
def test_save_load_large_model(tmp_path):
    """
    Test saving and loading a model with a large policy that is greater than 2GB. We
    test only one algorithm since all algorithms share the same code for loading and
    saving the model.
    """
    env = select_env(TD3)
    kwargs = dict(policy_kwargs=dict(net_arch=[8192, 8192, 8192]),
                  device="cpu")
    model = TD3("MlpPolicy", env, **kwargs)

    # test saving
    model.save(tmp_path / "test_save")

    # test loading
    model = TD3.load(str(tmp_path / "test_save.zip"), env=env, **kwargs)

    # clear file from os
    os.remove(tmp_path / "test_save.zip")
Example #8
def run(env, algname, filename):
    if algname == "TD3":
        model = TD3.load(f"{algname}_pkl")
    elif algname == "SAC":
        if filename:
            model = SAC.load(f"{filename}")
        else:
            model = SAC.load(f"{algname}_pkl")
    elif algname == "DDPG":
        model = DDPG.load(f"{algname}_pkl")
    else:
        raise "Wrong algorithm name provided."

    obs = env.reset()
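    # step the environment with the policy's actions until the episode ends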
    while True:
        action, _states = model.predict(obs)
        obs, rewards, done, info = env.step(action)
        env.render()
        if done:
            break
Example #9
def test_td3():
    log_dir = f"model_save/best_model_td3_sp2"
    env = ENV(istest=True)
    env.render = True
    env = Monitor(env, log_dir)
    model = TD3.load(log_dir)
    plot_results(f"model_save/")
    for i in range(10):
        state = env.reset()
        day = 0
        while True:
            action = model.predict(state)
            next_state, reward, done, info = env.step(action[0])
            state = next_state
            # print("trying:",day,"reward:", reward,"now profit:",env.profit)
            day+=1
            if done:
                print('stock: {}, total profit: {:.2f}%, buy hold: {:.2f}%, sp: {:.4f}, mdd: {:.2f}%, romad: {:.4f}'
                      .format(i, env.profit*100, env.buy_hold*100, env.sp, env.mdd*100, env.romad))
                break
Example #10
        else:
            model_path = args.model

        model = None
        if args.algorithm == 'DQN':
            model = DQN.load(model_path, tensorboard_log=args.tensorboard)
        elif args.algorithm == 'DDPG':
            model = DDPG.load(model_path, tensorboard_log=args.tensorboard)
        elif args.algorithm == 'A2C':
            model = A2C.load(model_path, tensorboard_log=args.tensorboard)
        elif args.algorithm == 'PPO':
            model = PPO.load(model_path, tensorboard_log=args.tensorboard)
        elif args.algorithm == 'SAC':
            model = SAC.load(model_path, tensorboard_log=args.tensorboard)
        elif args.algorithm == 'TD3':
            model = TD3.load(model_path, tensorboard_log=args.tensorboard)
        else:
            raise RuntimeError('Algorithm specified is not registered.')

        model.set_env(env)

    # ---------------------------------------------------------------------------- #
    #       Calculating total training timesteps based on number of episodes       #
    # ---------------------------------------------------------------------------- #
    n_timesteps_episode = env.simulator._eplus_one_epi_len / \
        env.simulator._eplus_run_stepsize
    timesteps = args.episodes * n_timesteps_episode - 1

    # ---------------------------------------------------------------------------- #
    #                                   CALLBACKS                                  #
    # ---------------------------------------------------------------------------- #
Example #11
        print("Training time: {}".format(t2 - t1))
        pprint(config)

        model.save("agents/{}_SB_policy".format(config["session_ID"]))
        env.close()

    if args["test"] and socket.gethostname() != "goedel":
        env_fun = my_utils.import_env(config["env_name"])
        config["seed"] = 1337

        env = env_fun(config)
        env.training = False
        env.norm_reward = False

        model = TD3.load("agents/{}".format(args["test_agent_path"]))
        # Normal testing
        N_test = 20
        total_rew = test_agent(env, model, deterministic=False, N=N_test)
        #total_rew = test_agent_mirrored(env, model, deterministic=False, N=N_test, perm=[-1, 0, 3, 2])
        print(f"Total test rew: {total_rew / N_test}")

        # Testing for permutation
        # N_test = 10
        # best_rew = 10000
        # best_perm = None
        # from itertools import permutations
        # for pos_perm in permutations(range(0, 4)):
        #     sign_perms = [[int(x) * 2 - 1 for x in list('{0:0b}'.format(i))] for i in range(16)]
        #     for sign_perm in sign_perms:
        #         pos_perm_copy = list(deepcopy(pos_perm))
Example #12
import gym

from stable_baselines3 import TD3

env = gym.make('Pendulum-v0')

# check env
#from stable_baselines3.common.env_checker import check_env
#check_env(env)

model = TD3.load("td3_pendulum")

obs = env.reset()
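# run the loaded policy in an endless render loop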
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()

Example #13
import gym
import numpy as np

import kuka_iiwa_insertion

from stable_baselines3 import TD3
from stable_baselines3.td3.policies import MlpPolicy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

env = gym.make('kuka_iiwa_insertion-v0', use_gui=False)

# The noise objects for TD3
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions),
                                 sigma=0.1 * np.ones(n_actions))
try:
    model = TD3.load("models/kuka_iiwa_insertion-v0",
                     env,
                     action_noise=action_noise,
                     verbose=1)
except Exception:
    # no saved model to resume from (or it failed to load): start a fresh agent
    model = TD3(MlpPolicy, env, action_noise=action_noise, verbose=1)

i = 0
save_interval = 10000
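# train indefinitely, saving a checkpoint every save_interval timesteps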
while True:
    i += save_interval
    model.learn(total_timesteps=save_interval, log_interval=10)
    model.save("models/kuka_iiwa_insertion-v0")
Example #14
 def load_weights(self, weights_file):
     """ Load the model from a zip archive """
     self.model = TD3.load(weights_file)
Example #15
    args = parser.parse_args()
    env_id = args.env
    seed = args.seed
    ep_count = args.ep

    # Load environment
    env: LifeEnv = gym.make(env_id)
    env.seed(seed)

    # Load policy/model
    # dir_path = "/home/nize/happinize/tmp/trial_30/"
    dir_path = "/home/nize/rl-baselines3-zoo/logs/td3/happinize-v1_20/"
    model_path = os.path.join(dir_path, "best_model.zip")
    #model_path = os.path.join(dir_path, "happinize-v1.zip")
    model = TD3.load(model_path, env)

    # Render a few episodes
    for i in range(ep_count):
        renderSingleEpisode(env, model)

    # Get some stats by running a number of episodes
    lifetime_sample_count = 1000
    age_max_sample = 0
    age_lifetime_peaks: List = []
    savings_lifetime_peaks = []
    samples_list: List = []
    action_labels: List = [
        "risk_level", "requested_monthly_consumption", "occupation_index"
    ]
    #reward_label: List = ["total_happiness"]
Example #16
else:
    model_path = args.model

model = None
if args.algorithm == 'DQN':
    model = DQN.load(model_path)
elif args.algorithm == 'DDPG':
    model = DDPG.load(model_path)
elif args.algorithm == 'A2C':
    model = A2C.load(model_path)
elif args.algorithm == 'PPO':
    model = PPO.load(model_path)
elif args.algorithm == 'SAC':
    model = SAC.load(model_path)
elif args.algorithm == 'TD3':
    model = TD3.load(model_path)
else:
    raise RuntimeError('Algorithm specified is not registered.')

# ---------------------------------------------------------------------------- #
#                             Execute loaded agent                             #
# ---------------------------------------------------------------------------- #
for i in range(args.episodes):
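    # roll out one episode, collecting per-step rewards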
    obs = env.reset()
    rewards = []
    done = False
    current_month = 0
    while not done:
        a, _ = model.predict(obs)
        obs, reward, done, info = env.step(a)
        rewards.append(reward)
Example #17
    algo = ARGS.exp.split("-")[2]

    if os.path.isfile(ARGS.exp + '/success_model.zip'):
        path = ARGS.exp + '/success_model.zip'
    elif os.path.isfile(ARGS.exp + '/best_model.zip'):
        path = ARGS.exp + '/best_model.zip'
    else:
        print("[ERROR]: no model under the specified path", ARGS.exp)
    if algo == 'a2c':
        model = A2C.load(path)
    if algo == 'ppo':
        model = PPO.load(path)
    if algo == 'sac':
        model = SAC.load(path)
    if algo == 'td3':
        model = TD3.load(path)
    if algo == 'ddpg':
        model = DDPG.load(path)

    #### Parameters to recreate the environment ################
    env_name = ARGS.exp.split("-")[1] + "-aviary-v0"
    OBS = ObservationType.KIN if ARGS.exp.split(
        "-")[3] == 'kin' else ObservationType.RGB
    if ARGS.exp.split("-")[4] == 'rpm':
        ACT = ActionType.RPM
    elif ARGS.exp.split("-")[4] == 'dyn':
        ACT = ActionType.DYN
    elif ARGS.exp.split("-")[4] == 'pid':
        ACT = ActionType.PID
    elif ARGS.exp.split("-")[4] == 'vel':
        ACT = ActionType.VEL
Example #18
import numpy as np
import gym
import gym_fishing
from stable_baselines3 import TD3
from stable_baselines3.common.env_checker import check_env

env = gym.make('fishing-v1')
check_env(env)

load = False
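# set load = True above to reload the saved agent instead of training from scratch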
if load:
    model = TD3.load("td3")
else:
    model = TD3('MlpPolicy', env, verbose=1)
    model.learn(total_timesteps=200)

## Simulate a run with the trained model, visualize result
df = env.simulate(model)
env.plot(df, "td3.png")

## Evaluate model
from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print("mean reward:", mean_reward, "std:", std_reward)

## Save and reload the model
if not load:
    model.save("td3")
    model = TD3.load("td3")