def run(args):
    """Run one rendered episode of the requested env/agent pair and log it.

    Args:
        args: parsed CLI namespace providing ``env`` and ``agent`` class names.
    """
    spec = load_spec("spec.json")
    exp_name = "-".join((args.env.lower(), args.agent.lower()))
    base_log_dir = os.path.join("data", exp_name)

    # Resolve env/agent classes by name from their modules.
    env = getattr(envs, args.env)(spec)
    agent = getattr(agents, args.agent)(env, spec)
    logger = logging.Logger(log_dir=base_log_dir, file_name='episodic.h5')

    obs = env.reset()
    done = False
    while not done:
        env.render()
        with torch.no_grad():
            action = agent.act(obs)
        next_obs, reward, done, info = env.step(action)
        logger.record(**info)
        agent.update(obs, action, reward, next_obs, done)
        obs = next_obs

    env.close()
    logger.close()
def test(path, mode, **kwargs):
    """Evaluate trained GAN checkpoints on sample data and log the outputs.

    Args:
        path: location of the ``.pth`` checkpoint file(s).
        mode: only ``"gan"`` is handled; anything else is a no-op.
        **kwargs: must contain ``"samples"``, the sample ``.h5`` source.
    """
    if mode != "gan":
        return

    samplefiles = utils.parse_file([kwargs["samples"]], ext="h5")
    trainfiles = utils.parse_file(path, ext="pth")
    gan_params = PARAMS["GAN"]
    agent = gan.GAN(
        lr=gan_params["lr"],
        x_size=gan_params["x_size"],
        u_size=gan_params["u_size"],
        z_size=gan_params["z_size"],
    )
    for trainfile in trainfiles:
        agent.load(trainfile)
        agent.eval()
        logger = logging.Logger(path="data/tmp.h5", max_len=500)
        dataloader = gan.get_dataloader(samplefiles, shuffle=False)
        for state, action in tqdm.tqdm(dataloader):
            fake_action = agent.get_action(state)
            state, action = map(torch.squeeze, (state, action))
            logger.record(state=state, action=action,
                          fake_action=fake_action.ravel())
        logger.close()
        print(f"Test data is saved in {logger.path}")
def _sample_prog(i, log_dir):
    """Roll out one seeded episode with the base env/agent and log it.

    Args:
        i: worker index; used both as the RNG seed and the file name.
        log_dir: directory for the per-episode ``.h5`` log.
    """
    np.random.seed(i)
    env = envs.BaseEnv(
        initial_perturb=[0, 0, 0, 0.1],
        dt=0.01,
        max_t=20,
        solver="rk4",
        ode_step_len=1,
    )
    agent = agents.BaseAgent(env, theta_init=0)
    agent.add_noise(scale=0.03, tau=2)
    logger = logging.Logger(
        log_dir=log_dir, file_name=f"{i:03d}.h5", max_len=100)

    obs = env.reset("random")
    done = False
    while not done:
        action = agent.get_action(obs)
        next_obs, reward, done, info = env.step(action)
        logger.record(**info)
        obs = next_obs

    env.close()
    logger.close()
def _sample(env, num, tqdm=False):
    """Collect ``num`` random-reset episodes from ``env`` into per-episode files.

    Args:
        env: environment driving itself via ``env.step()``; episode data goes
            through ``env.logger``.
        num: number of episodes to sample.
        tqdm: show a progress bar (note: the flag shadows the module name).
    """
    envdir = env.envdir
    iterator = trange(num) if tqdm else range(num)
    for i in iterator:
        env.logger = logging.Logger(
            path=os.path.join(envdir, f"sample-{i:03d}.h5"))
        env.reset("random")
        done = False
        while not done:
            done = env.step()
        env.close()
def __init__(self, systems_dict, dt=0.01, max_t=1, tmp_dir='data/tmp',
             logging_off=True, solver="odeint", ode_step_len=2,
             ode_option=None, name=None):
    """Set up a multi-system simulation environment.

    Args:
        systems_dict: mapping of names to subsystem objects exposing
            ``state_shape``.
        dt: simulation time step (also one logger span).
        max_t: episode length in seconds.
        tmp_dir: log directory when logging is enabled.
        logging_off: disable the built-in history logger.
        solver: ``"odeint"`` or ``"rk4"``.
        ode_step_len: number of internal integration sub-steps (int).
        ode_option: extra keyword options for the ODE solver; defaults to
            an empty dict. (Was a mutable default ``{}`` shared across all
            instances -- fixed.)
        name: optional environment name.

    Raises:
        NotImplementedError: if the subclass defines no ``action_space``.
        ValueError: for an unknown solver or a non-int ``ode_step_len``.
    """
    self.name = name
    self.systems_dict = systems_dict
    self.systems = systems_dict.values()
    # Total flattened state size over all subsystems.
    self.state_shape = (sum(
        functools.reduce(lambda a, b: a * b, system.state_shape)
        for system in self.systems
    ), )
    self.indexing()

    if not hasattr(self, 'observation_space'):
        self.observation_space = infer_obs_space(systems_dict)
        print("Observation space is inferred using the initial states "
              f"of the systems: {self.systems_dict.keys()}")

    if not hasattr(self, 'action_space'):
        raise NotImplementedError('The action_space is not defined.')

    self.clock = Clock(dt=dt, max_t=max_t)

    self.logging_off = logging_off
    if not logging_off:
        self.logger = logging.Logger(log_dir=tmp_dir, file_name='history.h5')

    # ODE Solver
    if solver == "odeint":
        self.solver = odeint
    elif solver == "rk4":
        self.solver = rk4
    else:
        # Previously an unknown solver silently left ``self.solver`` unset,
        # deferring the failure to the first step; fail fast instead.
        raise ValueError(f"Unknown solver: {solver}")

    self.ode_func = self.ode_wrapper(self.set_dot)
    # Per-instance dict; avoids the shared mutable-default pitfall.
    self.ode_option = {} if ode_option is None else ode_option
    self.tqdm_bar = None

    if not isinstance(ode_step_len, int):
        raise ValueError("ode_step_len should be integer.")

    self.t_span = np.linspace(0, dt, ode_step_len + 1)
    self.delay = None
def run(env):
    """Drive ``env`` with a constant zero action until done.

    Returns:
        The path of the written log file.
    """
    logger = logging.Logger(log_dir="data", file_name="tmp.h5")
    env.reset()
    done = False
    while not done:
        next_obs, reward, done, info = env.step(0)
        logger.record(**info)
    env.close()
    logger.close()
    return logger.path
def main():
    """Run one rendered episode of ``Env`` (logged via ``env.logger``), then plot."""
    env = Env()
    env.logger = logging.Logger("data/tmp.h5")
    env.reset()
    done = False
    while not done:
        env.render()
        done = env.step()
    env.close()
    figure.plot()
def _train(agent):
    """Iterate policy evaluation until the critic weights stop changing.

    Records the parameter history to ``agent.trainpath`` and stops after at
    most 50 epochs or once the ``wc`` update is numerically zero.
    """
    logger = logging.Logger(path=agent.trainpath)
    logger.set_info(agent.get_info())

    params = agent.observe_dict()
    eps = 1e-15  # convergence tolerance on the critic weights
    for epoch in range(50):
        params_next = agent.policy_evaluation()
        logger.record(epoch=epoch, params=params)
        # Stop once the critic weight change is below tolerance.
        if np.linalg.norm(params_next["wc"] - params["wc"]) < eps:
            break
        params = params_next
    logger.close()
def sample(obj, **kwargs):
    """Generate datasets for each requested experiment mode as HDF5 files.

    Args:
        obj: context object providing ``sample_dir``.
        **kwargs: expects ``"mode"`` (iterable of experiment names),
            ``"all"`` (override with every mode), and ``"noise"``.
    """
    np.random.seed(0)
    if kwargs["all"]:
        exps = ("even", "sparse", "det", "detmult")
    else:
        exps = kwargs["mode"]

    for exp in exps:
        samplepath = os.path.join(obj.sample_dir, exp + ".h5")
        print(f"Sample for {exp} ...")
        get_data = DataGen(exp, noise=kwargs["noise"])
        logger = logging.Logger(path=samplepath, max_len=500)
        t0 = time.time()
        for _ in tqdm.trange(10000):
            x, u, mask = get_data()
            logger.record(state=[x], action=[u], mask=[mask])
        logger.close()
        print(f"Saved in {samplepath}.")
        print(f"Elapsed time: {time.time() - t0:5.2f} seconds.")
def run(path, **kwargs):
    """Replay a trained agent loaded from ``path`` and optionally plot it.

    Args:
        path: experiment data path; its basename encodes ``env-agent-...``.
        **kwargs: expects ``"out"`` (output file name) and ``"with_plot"``.
    """
    logger = logging.Logger(log_dir=".", file_name=kwargs["out"], max_len=100)
    data = logging.load(path)
    expname = os.path.basename(path)
    # The experiment name encodes the env and agent class names.
    envname, agentname, *_ = expname.split("-")

    env = getattr(envs, envname)(
        initial_perturb=[1, 0.0, 0, np.deg2rad(10)],
        dt=0.01, max_t=40, solver="rk4", ode_step_len=1)
    agent = getattr(agents, agentname)(
        env,
        lrw=1e-2, lrv=1e-2, lrtheta=1e-2,
        w_init=0.03, v_init=0.03, theta_init=0,
        maxlen=100, batch_size=16)
    agent.load_weights(data)

    print(f"Runnning {expname} ...")
    _run(env, agent, logger, expname, **kwargs)
    logger.close()

    if kwargs["with_plot"]:
        import figures

        canvas = []
        for file in tqdm.tqdm(utils.parse_file(kwargs["out"])):
            canvas = figures.plot_single(file, canvas=canvas)
        figures.show()
# Scene-1 driver script: runs every controller in `ctrls` on `env`
# (both defined earlier in the file) and logs each run under data/scene1.
# NOTE(review): the loop extent below was reconstructed from a collapsed
# source line; it assumes the whole episode runs once per controller.

# Ask before wiping any previous results.
if os.path.exists("data"):
    if input(f"Delete \"data\"? [Y/n]: ") in ["", "Y", "y"]:
        shutil.rmtree("data")
    else:
        sys.exit()

# Morphing controller switches at t = 0 and t = 20 (presumably seconds --
# TODO confirm against Switching's definition).
morphing_ctrl = Switching(env, seq=[0, 20])
env.set_morphing_ctrl(morphing_ctrl)

# Run one logged, rendered episode per named controller.
for cname, ctrl in ctrls.items():
    ctrl.name = cname
    env.set_ctrl(ctrl)
    expname = env.get_expname()
    path = os.path.join("data", "scene1", expname + ".h5")
    env.logger = logging.Logger(path=path)
    env.reset()
    while True:
        env.render()
        done = env.step()
        if done:
            break
    env.close()

import matplotlib.pyplot as plt
import os
from glob import glob

# Collect all logged runs for plotting.
scenepath = os.path.join("data", "scene1")
pathlist = glob(os.path.join(scenepath, "*.h5"))
# cnamelist = [
def train(sample, mode, **kwargs):
    """Train the GAN and/or COPDAC agent on the given sample files.

    Args:
        sample: sample source passed to ``utils.parse_file``.
        mode: ``"gan"``, ``"copdac"``, or ``"all"`` for both.
        **kwargs: training options (``gan_dir``, ``gan_lr``, ``use_cuda``,
            ``continue``, ``save_interval``, ``max_epoch``, ``batch_size``,
            ``copdac_dir``, ``with_gan``, ``with_reg``).
    """
    samplefiles = utils.parse_file(sample, ext="h5")
    if mode in ("gan", "all"):
        _train_gan_mode(samplefiles, **kwargs)
    if mode in ("copdac", "all"):
        _train_copdac_mode(samplefiles, **kwargs)


def _train_gan_mode(samplefiles, **kwargs):
    """Train the GAN on (state, action) samples, checkpointing periodically."""
    torch.manual_seed(0)
    np.random.seed(0)
    gandir = kwargs["gan_dir"]
    histpath = os.path.join(gandir, "train-history.h5")
    print("Train GAN ...")
    agent = gan.GAN(
        lr=kwargs["gan_lr"],
        x_size=PARAMS["GAN"]["x_size"],
        u_size=PARAMS["GAN"]["u_size"],
        z_size=PARAMS["GAN"]["z_size"],
        use_cuda=kwargs["use_cuda"],
    )
    if kwargs["continue"] is not None:
        # Resume: restore weights and append to the existing history.
        epoch_start = agent.load(kwargs["continue"])
        logger = logging.Logger(path=histpath,
                                max_len=kwargs["save_interval"], mode="r+")
    else:
        epoch_start = 0
        logger = logging.Logger(path=histpath,
                                max_len=kwargs["save_interval"])

    t0 = time.time()
    epoch_final = epoch_start + kwargs["max_epoch"]
    for epoch in tqdm.trange(epoch_start, epoch_final + 1):
        dataloader = gan.get_dataloader(samplefiles, shuffle=True,
                                        batch_size=kwargs["batch_size"])
        loss_d = loss_g = 0
        for data in tqdm.tqdm(dataloader):
            agent.set_input(data)
            agent.train()
            # NOTE(review): ``.numpy()`` without ``.cpu()`` will fail when
            # ``use_cuda`` puts losses on the GPU -- confirm.
            loss_d += agent.loss_d.mean().detach().numpy()
            loss_g += agent.loss_g.mean().detach().numpy()
        logger.record(epoch=epoch, loss_d=loss_d, loss_g=loss_g)
        # BUGFIX: the original compared against the (exclusive) trange stop,
        # which the loop variable never reaches, so the final checkpoint was
        # only written when max_epoch happened to hit save_interval.
        if epoch % kwargs["save_interval"] == 0 or epoch == epoch_final:
            savepath = os.path.join(gandir, f"trained-{epoch:05d}.pth")
            agent.save(epoch, savepath)
            tqdm.tqdm.write(f"Weights are saved in {savepath}.")
    # BUGFIX: the original GAN branch never closed its logger.
    logger.close()
    print(f"Elapsed time: {time.time() - t0:5.2f} sec")


def _train_copdac_mode(samplefiles, **kwargs):
    """Train the COPDAC agent, optionally with a GAN model and regularizer."""
    np.random.seed(1)
    env = envs.BaseEnv(initial_perturb=[0, 0, 0, 0.2])
    copdacdir = kwargs["copdac_dir"]
    agentname = "COPDAC"
    Agent = getattr(agents, agentname)
    agent = Agent(
        env,
        lrw=PARAMS["COPDAC"]["lrw"],
        lrv=PARAMS["COPDAC"]["lrv"],
        lrtheta=PARAMS["COPDAC"]["lrtheta"],
        w_init=PARAMS["COPDAC"]["w_init"],
        v_init=PARAMS["COPDAC"]["v_init"],
        # NOTE(review): reading ``lrv`` for theta_init looks like a
        # copy-paste slip (cf. the other runner passing theta_init=0);
        # kept as-is to preserve behavior -- verify against PARAMS.
        theta_init=PARAMS["COPDAC"]["lrv"],
        maxlen=PARAMS["COPDAC"]["maxlen"],
        batch_size=PARAMS["COPDAC"]["batch_size"],
    )
    expname = "-".join(type(n).__name__ for n in (env, agent))
    if kwargs["with_gan"]:
        expname += "-gan"
        agent.set_gan(kwargs["with_gan"], PARAMS["COPDAC"]["lrg"])
    if kwargs["with_reg"]:
        expname += "-reg"
        agent.set_reg(PARAMS["COPDAC"]["lrc"])

    histpath = os.path.join(copdacdir, expname + ".h5")
    if kwargs["continue"] is not None:
        epoch_start, i = agent.load(kwargs["continue"])
        logger = logging.Logger(path=histpath, max_len=100, mode="r+")
    else:
        epoch_start, i = 0, 0
        logger = logging.Logger(path=histpath, max_len=100)

    print(f"Training {agentname}...")
    epoch_end = epoch_start + kwargs["max_epoch"]
    for epoch in tqdm.trange(epoch_start, epoch_end):
        dataloader = gan.get_dataloader(
            samplefiles,
            keys=("state", "action", "reward", "next_state"),
            shuffle=True,
            batch_size=64,
        )
        for data in tqdm.tqdm(dataloader, desc=f"Epoch {epoch}"):
            agent.set_input(data)
            agent.train()
            # Record a snapshot every save_interval training steps.
            if i % kwargs["save_interval"] == 0 or i == len(dataloader):
                logger.record(epoch=epoch, i=i, w=agent.w, v=agent.v,
                              theta=agent.theta, loss=agent.get_losses())
            i += 1
    logger.close()
def train(obj, samples, **kwargs):
    """Train one GAN per sample source (a .h5 file or directory of them).

    Args:
        obj: context object providing ``sample_dir`` and ``gan_dir``.
        samples: iterable of sample files/directories.
        **kwargs: training options (``continue``, ``save_interval``, ``lr``,
            ``z_size``, ``use_cuda``, ``batch_size``, ``max_epoch``).

    Raises:
        ValueError: if a sample is neither a file nor a directory.
    """
    for sample in samples:
        basedir = os.path.relpath(sample, obj.sample_dir)
        if os.path.isdir(sample):
            samplefiles = sorted(glob.glob(os.path.join(sample, "*.h5")))
        elif os.path.isfile(sample):
            samplefiles = sample
            basedir = os.path.splitext(basedir)[0]
        else:
            raise ValueError("unknown sample type.")

        gandir = os.path.join(obj.gan_dir, basedir)
        # Start from a clean directory unless resuming a previous run.
        if kwargs["continue"] is None and os.path.exists(gandir):
            shutil.rmtree(gandir)
        os.makedirs(gandir, exist_ok=True)
        # BUGFIX: moved after the rmtree/makedirs above. The original wrote
        # sample_path.h5 first, so it was either deleted by the cleanup or
        # the save failed outright when gandir did not yet exist.
        torch.save(
            samplefiles,
            os.path.join(gandir, "sample_path.h5"),
        )

        print(f"Train GAN for sample ({sample}) ...")
        save_interval = int(kwargs["save_interval"])
        agent = gan.GAN(
            lr=kwargs["lr"],
            x_size=1,
            u_size=1,
            z_size=kwargs["z_size"],
            use_cuda=kwargs["use_cuda"],
        )
        # One training epoch, bound to this agent and data.
        prog = functools.partial(
            _gan_prog,
            agent=agent,
            files=samplefiles,
            batch_size=kwargs["batch_size"],
        )

        histpath = os.path.join(gandir, "train_history.h5")
        if kwargs["continue"] is not None:
            epoch_start = agent.load(kwargs["continue"])
            logger = logging.Logger(path=histpath, max_len=save_interval,
                                    mode="r+")
        else:
            epoch_start = 0
            logger = logging.Logger(path=histpath, max_len=save_interval)

        t0 = time.time()
        epoch_final = epoch_start + kwargs["max_epoch"]
        for epoch in tqdm.trange(epoch_start, epoch_final + 1):
            loss_d, loss_g = prog(epoch)
            logger.record(epoch=epoch, loss_d=loss_d, loss_g=loss_g)
            # BUGFIX: the original's final-save test compared against the
            # (exclusive) trange stop, which the loop never reaches; save
            # on the last epoch instead.
            if epoch % save_interval == 0 or epoch == epoch_final:
                savepath = os.path.join(gandir, f"trained_{epoch:05d}.pth")
                agent.save(epoch, savepath)
                tqdm.tqdm.write(f"Weights are saved in {savepath}.")
        logger.close()
        print(f"Elapsed time: {time.time() - t0:5.2f} sec")