Beispiel #1
0
def load_trials(trial_dir):
    """Load every trial workspace found under ``jup_dir + trial_dir``.

    Each directory entry is passed to ``load_workspace``. The reward
    histories (``ws["raw_rew_hist"]``) are truncated to the length of the
    shortest one so they can be stacked column-wise into a single array.

    Args:
        trial_dir: path suffix appended (by plain string concatenation) to
            the module-level ``jup_dir``.

    Returns:
        ``(ws_list, model_list, rewards)``. ``rewards`` has shape
        ``(min_length, len(ws_list))``, and column ``i`` holds the truncated
        history of ``ws_list[i]``. Trials are ordered by entry name.

    Raises:
        ValueError: if the directory has no entries.
    """
    directory = jup_dir + trial_dir

    # os.scandir yields entries in arbitrary order. Sort by name so the
    # column order of `rewards` is reproducible. The `with` closes the
    # directory handle deterministically.
    with os.scandir(directory) as scan:
        entries = sorted(scan, key=lambda e: e.name)

    if not entries:
        # Otherwise the shortest-length computation has nothing to work on.
        raise ValueError("no trials found in %r" % (directory,))

    ws_list = []
    model_list = []
    for entry in entries:
        model, _env, _args, ws = load_workspace(entry.path)
        ws_list.append(ws)
        model_list.append(model)

    # Truncate every history to the shortest one so they fit in one array.
    min_length = min(len(ws["raw_rew_hist"]) for ws in ws_list)
    rewards = np.zeros((min_length, len(ws_list)))
    for i, ws in enumerate(ws_list):
        rewards[:, i] = np.array(ws["raw_rew_hist"][:min_length])

    return ws_list, model_list, rewards
Beispiel #2
0
    ep_acts = torch.stack(acts_list).reshape(-1, act_size)
    ep_rews = torch.stack(rews_list).reshape(-1, 1)

    return ep_obs1, ep_acts, ep_rews, None, ep_obs1


#%%
# Load every trial workspace in one experiment directory. Stack the reward
# histories, truncated to the shortest one, into `rewards` (one column per
# trial). Print the final mean/std across trials and plot the curves.
jup_dir = "/home/sgillen/work/"
trial_dir = "ssac/switched_rl/data_needle/50k_slow_longer"
directory = jup_dir + trial_dir

# os.scandir yields entries in arbitrary order. Sort by name so the column
# order of `rewards` is reproducible. The `with` closes the directory handle
# deterministically.
with os.scandir(directory) as scan:
    trial_entries = sorted(scan, key=lambda e: e.name)
if not trial_entries:
    # Fail with a clear message: an empty directory leaves nothing to stack.
    raise ValueError("no trials found in %r" % (directory,))

ws_list = []
model_list = []
for entry in trial_entries:
    model, env, args, ws = load_workspace(entry.path)
    ws_list.append(ws)
    model_list.append(model)

# Truncate every history to the shortest one so they fit in one array.
min_length = min(len(w["raw_rew_hist"]) for w in ws_list)
rewards = np.zeros((min_length, len(ws_list)))
for i, ws in enumerate(ws_list):
    rewards[:, i] = np.array(ws["raw_rew_hist"][:min_length])

# Mean and std, across trials, of the last retained reward entry.
print("seagul", rewards[-1, :].mean(), rewards[-1, :].std())
fig, ax = smooth_bounded_curve(rewards)
ssac_size = rewards.shape[0]
Beispiel #3
0
import matplotlib.pyplot as plt
import numpy as np
from numpy import pi
from stable_baselines import TD3
from stable_baselines.results_plotter import load_results
import torch.utils.data
from torch.multiprocessing import Pool
from itertools import product
import os
import matplotlib

# Directory containing this script, with a trailing separator so relative
# data paths can be appended with `+`. Uses os.path rather than splitting on
# "/" so the result is also correct where the separator is "\\".
script_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "")
print(script_path)

# Load one reference trial from the data_needle results. Then pull its
# environment name and environment config out of the returned workspace.
model, env, args, ws = load_workspace(
    script_path
    + "data_needle/50k_slow_longer/trial3738150792--3-11_18-0"
)
env_name, config = ws['env_name'], ws['env_config']


def load_trials(trial_dir):
    directory = script_path + trial_dir

    ws_list = []
    model_list = []
    min_length = float('inf')
    for entry in os.scandir(directory):
        model, env, args, ws = load_workspace(entry.path)

        if len(ws["raw_rew_hist"]) < min_length:
            min_length = len(ws["raw_rew_hist"])