def load_trials(trial_dir, base_dir=None):
    """Load every trial workspace in a directory and stack their reward histories.

    Each entry of ``os.path.join(base_dir, trial_dir)`` is passed to
    ``load_workspace``. The ``"raw_rew_hist"`` sequence of every workspace is
    truncated to the length of the shortest one so that all of them fit into a
    single matrix.

    Args:
        trial_dir: Trial directory, relative to ``base_dir``.
        base_dir: Root directory joined in front of ``trial_dir``. Defaults to
            the module-level ``jup_dir`` (looked up at call time), which keeps
            the original single-argument call working.

    Returns:
        ``(ws_list, model_list, rewards)``:
        ``ws_list`` / ``model_list`` hold the workspace and model loaded for
        each entry, in entry-name order; ``rewards`` is an array of shape
        ``(min_length, n_trials)`` whose column ``i`` is the truncated
        ``"raw_rew_hist"`` of ``ws_list[i]``. An empty directory yields
        ``([], [], np.zeros((0, 0)))``.

    Raises:
        FileNotFoundError / NotADirectoryError: propagated from ``os.scandir``
        when the resolved directory is missing or not a directory.
    """
    if base_dir is None:
        base_dir = jup_dir  # module-level global defined elsewhere in this file
    directory = os.path.join(base_dir, trial_dir)

    # os.scandir yields entries in filesystem-dependent order; sort by name so
    # column i of `rewards` maps to the same trial on every run. The `with`
    # closes the directory handle even if a later load fails.
    with os.scandir(directory) as it:
        entries = sorted(it, key=lambda entry: entry.name)

    ws_list = []
    model_list = []
    for entry in entries:
        model, _env, _args, ws = load_workspace(entry.path)
        ws_list.append(ws)
        model_list.append(model)

    if not ws_list:
        # Nothing to stack: return an empty matrix rather than failing on an
        # undefined (infinite) minimum history length.
        return ws_list, model_list, np.zeros((0, 0))

    # Truncate every history to the shortest one so they stack column-wise.
    min_length = min(len(ws["raw_rew_hist"]) for ws in ws_list)
    rewards = np.zeros((min_length, len(ws_list)))
    for i, ws in enumerate(ws_list):
        rewards[:, i] = np.asarray(ws["raw_rew_hist"][:min_length])
    return ws_list, model_list, rewards
ep_acts = torch.stack(acts_list).reshape(-1, act_size) ep_rews = torch.stack(rews_list).reshape(-1, 1) return ep_obs1, ep_acts, ep_rews, None, ep_obs1 #%% jup_dir = "/home/sgillen/work/" trial_dir = "ssac/switched_rl/data_needle/50k_slow_longer" directory = jup_dir + trial_dir ws_list = [] model_list = [] min_length = float('inf') for entry in os.scandir(directory): model, env, args, ws = load_workspace(entry.path) if len(ws["raw_rew_hist"]) < min_length: min_length = len(ws["raw_rew_hist"]) ws_list.append(ws) model_list.append(model) min_length = int(min_length) rewards = np.zeros((min_length, len(ws_list))) for i, ws in enumerate(ws_list): rewards[:, i] = np.array(ws["raw_rew_hist"][:min_length]) print("seagul", rewards[-1, :].mean(), rewards[-1, :].std()) fig, ax = smooth_bounded_curve(rewards) ssac_size = rewards.shape[0]
import matplotlib.pyplot as plt import numpy as np from numpy import pi from stable_baselines import TD3 from stable_baselines.results_plotter import load_results import torch.utils.data from torch.multiprocessing import Pool from itertools import product import os import matplotlib script_path = os.path.realpath(__file__).split("/")[:-1] script_path = "/".join(script_path) + "/" print(script_path) model, env, args, ws = load_workspace( script_path + "data_needle/50k_slow_longer/trial3738150792--3-11_18-0") env_name = ws['env_name'] config = ws['env_config'] def load_trials(trial_dir): directory = script_path + trial_dir ws_list = [] model_list = [] min_length = float('inf') for entry in os.scandir(directory): model, env, args, ws = load_workspace(entry.path) if len(ws["raw_rew_hist"]) < min_length: min_length = len(ws["raw_rew_hist"])