Example #1
def run(args):
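    # Run one episode: load the spec, build the env/agent classes by name,
    # and log each step's info dict to data/<env>-<agent>/episodic.h5.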
    spec = load_spec("spec.json")
    exp_name = '-'.join([args.env.lower(), args.agent.lower()])
    BASE_LOG_DIR = os.path.join("data", exp_name)

    env = getattr(envs, args.env)(spec)
    agent = getattr(agents, args.agent)(env, spec)
    logger = logging.Logger(log_dir=BASE_LOG_DIR, file_name='episodic.h5')

    obs = env.reset()

    while True:
        env.render()

        with torch.no_grad():
            action = agent.act(obs)

        next_obs, reward, done, info = env.step(action)

        logger.record(**info)

        agent.update(obs, action, reward, next_obs, done)

        obs = next_obs

        if done:
            break

    env.close()
    logger.close()
Example #2
def test(path, mode, **kwargs):
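    # Evaluate trained GAN checkpoints: for each .pth file, generate fake
    # actions over the sample data and log them alongside the real ones.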
    if mode == "gan":
        samplefiles = utils.parse_file([kwargs["samples"]], ext="h5")
        trainfiles = utils.parse_file(path, ext="pth")
        agent = gan.GAN(
            lr=PARAMS["GAN"]["lr"],
            x_size=PARAMS["GAN"]["x_size"],
            u_size=PARAMS["GAN"]["u_size"],
            z_size=PARAMS["GAN"]["z_size"],
        )

        for trainfile in trainfiles:
            agent.load(trainfile)
            agent.eval()

            logger = logging.Logger(path="data/tmp.h5", max_len=500)

            dataloader = gan.get_dataloader(samplefiles, shuffle=False)
            for i, (state, action) in enumerate(tqdm.tqdm(dataloader)):
                fake_action = agent.get_action(state)
                state, action = map(torch.squeeze, (state, action))
                fake_action = fake_action.ravel()
                logger.record(state=state,
                              action=action,
                              fake_action=fake_action)

            logger.close()

            print(f"Test data is saved in {logger.path}")
Example #3
def _sample_prog(i, log_dir):
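    # Roll out one seeded episode with a noisy base agent and save the
    # per-step info dicts to <log_dir>/<i>.h5.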
    np.random.seed(i)
    env = envs.BaseEnv(
        initial_perturb=[0, 0, 0, 0.1],
        dt=0.01,
        max_t=20,
        solver="rk4",
        ode_step_len=1,
    )
    agent = agents.BaseAgent(env, theta_init=0)
    agent.add_noise(scale=0.03, tau=2)
    file_name = f"{i:03d}.h5"

    logger = logging.Logger(log_dir=log_dir, file_name=file_name, max_len=100)

    obs = env.reset("random")
    while True:
        action = agent.get_action(obs)
        next_obs, reward, done, info = env.step(action)
        logger.record(**info)
        obs = next_obs

        if done:
            break

    env.close()
    logger.close()
Example #4
def _sample(env, num, tqdm=False):
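    # Collect `num` episodes from random initial states, one HDF5 file each,
    # optionally with a tqdm progress bar.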
    envdir = env.envdir
    custom_range = trange if tqdm else range
    for i in custom_range(num):
        path = os.path.join(envdir, f"sample-{i:03d}.h5")
        env.logger = logging.Logger(path=path)
        env.reset("random")
        while True:
            done = env.step()
            if done:
                break

        env.close()
Example #5
    def __init__(self,
                 systems_dict,
                 dt=0.01,
                 max_t=1,
                 tmp_dir='data/tmp',
                 logging_off=True,
                 solver="odeint",
                 ode_step_len=2,
                 ode_option={},
                 name=None):
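        # Flatten the subsystems' state shapes into one state vector, then
        # set up the logger, the simulation clock, and the ODE solver.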
        self.name = name
        self.systems_dict = systems_dict
        self.systems = systems_dict.values()
        self.state_shape = (sum([
            functools.reduce(lambda a, b: a * b, system.state_shape)
            for system in self.systems
        ]), )
        self.indexing()

        if not hasattr(self, 'observation_space'):
            self.observation_space = infer_obs_space(systems_dict)
            print("Observation space is inferred using the initial states "
                  f"of the systems: {self.systems_dict.keys()}")

        if not hasattr(self, 'action_space'):
            raise NotImplementedError('The action_space is not defined.')

        self.clock = Clock(dt=dt, max_t=max_t)

        self.logging_off = logging_off
        if not logging_off:
            self.logger = logging.Logger(log_dir=tmp_dir,
                                         file_name='history.h5')

        # ODE Solver
        if solver == "odeint":
            self.solver = odeint
        elif solver == "rk4":
            self.solver = rk4
        else:
            raise ValueError(f"Unknown solver: {solver}")

        self.ode_func = self.ode_wrapper(self.set_dot)
        self.ode_option = ode_option
        self.tqdm_bar = None

        if not isinstance(ode_step_len, int):
            raise ValueError("ode_step_len should be integer.")

        self.t_span = np.linspace(0, dt, ode_step_len + 1)

        self.delay = None
Example #6
def run(env):
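    # Run one episode with a constant zero action, logging each step's info
    # to data/tmp.h5, and return the log path.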
    logger = logging.Logger(log_dir="data", file_name="tmp.h5")
    env.reset()
    while True:
        action = 0
        next_obs, reward, done, info = env.step(action)

        logger.record(**info)

        if done:
            break

    env.close()
    logger.close()
    return logger.path
Example #7
def main():
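    # Render one episode with the environment's attached logger, then plot.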
    env = Env()
    env.logger = logging.Logger("data/tmp.h5")

    env.reset()

    while True:
        env.render()
        done = env.step()
        if done:
            break

    env.close()

    figure.plot()
Example #8
def _train(agent):
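    # Iterate policy evaluation, logging the parameters each epoch, until the
    # critic weights `wc` stop changing (to within `eps`).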
    trainpath = agent.trainpath
    logger = logging.Logger(path=trainpath)
    logger.set_info(agent.get_info())

    params = agent.observe_dict()
    eps = 1e-15
    for epoch in range(50):
        params_next = agent.policy_evaluation()

        logger.record(epoch=epoch, params=params)

        if np.linalg.norm(params_next["wc"] - params["wc"]) < eps:
            break

        params = params_next

    logger.close()
Example #9
def sample(obj, **kwargs):
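    # Generate 10,000 (state, action, mask) samples per experiment mode and
    # save each set to <sample_dir>/<mode>.h5.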
    np.random.seed(0)
    exps = kwargs["mode"]
    if kwargs["all"]:
        exps = ("even", "sparse", "det", "detmult")

    for exp in exps:
        samplepath = os.path.join(obj.sample_dir, exp + ".h5")

        print(f"Sample for {exp} ...")
        get_data = DataGen(exp, noise=kwargs["noise"])

        logger = logging.Logger(path=samplepath, max_len=500)

        t0 = time.time()
        for _ in tqdm.trange(10000):
            x, u, mask = get_data()
            logger.record(state=[x], action=[u], mask=[mask])

        logger.close()
        print(f"Saved in {samplepath}.")
        print(f"Elapsed time: {time.time() - t0:5.2f} seconds.")
Example #10
def run(path, **kwargs):
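    # Rebuild the env/agent pair encoded in the experiment name, load the
    # saved weights, rerun the experiment, and optionally plot the results.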
    logger = logging.Logger(log_dir=".", file_name=kwargs["out"], max_len=100)
    data = logging.load(path)
    expname = os.path.basename(path)
    envname, agentname, *_ = expname.split("-")
    env = getattr(envs, envname)(initial_perturb=[1, 0.0, 0,
                                                  np.deg2rad(10)],
                                 dt=0.01,
                                 max_t=40,
                                 solver="rk4",
                                 ode_step_len=1)
    agent = getattr(agents, agentname)(env,
                                       lrw=1e-2,
                                       lrv=1e-2,
                                       lrtheta=1e-2,
                                       w_init=0.03,
                                       v_init=0.03,
                                       theta_init=0,
                                       maxlen=100,
                                       batch_size=16)
    agent.load_weights(data)

    print(f"Runnning {expname} ...")
    _run(env, agent, logger, expname, **kwargs)

    logger.close()

    if kwargs["with_plot"]:
        import figures

        files = utils.parse_file(kwargs["out"])
        canvas = []
        for file in tqdm.tqdm(files):
            canvas = figures.plot_single(file, canvas=canvas)

        figures.show()
Example #11
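    # Scene 1: run every controller once, each into its own HDF5 log under
    # data/scene1, after optionally wiping the old data directory.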
    if os.path.exists("data"):
        if input(f"Delete \"data\"? [Y/n]: ") in ["", "Y", "y"]:
            shutil.rmtree("data")
        else:
            sys.exit()

    morphing_ctrl = Switching(env, seq=[0, 20])
    env.set_morphing_ctrl(morphing_ctrl)

    for cname, ctrl in ctrls.items():
        ctrl.name = cname
        env.set_ctrl(ctrl)

        expname = env.get_expname()
        path = os.path.join("data", "scene1", expname + ".h5")
        env.logger = logging.Logger(path=path)
        env.reset()
        while True:
            env.render()
            done = env.step()
            if done:
                break
        env.close()

    import matplotlib.pyplot as plt
    from glob import glob

    scenepath = os.path.join("data", "scene1")
    pathlist = glob(os.path.join(scenepath, "*.h5"))
    # cnamelist = [
Example #12
def train(sample, mode, **kwargs):
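    # Train a GAN and/or a COPDAC agent on the sampled data, logging training
    # progress and periodically checkpointing the weights.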
    samplefiles = utils.parse_file(sample, ext="h5")

    if mode == "gan" or mode == "all":
        torch.manual_seed(0)
        np.random.seed(0)

        gandir = kwargs["gan_dir"]
        histpath = os.path.join(gandir, "train-history.h5")

        print("Train GAN ...")

        agent = gan.GAN(
            lr=kwargs["gan_lr"],
            x_size=PARAMS["GAN"]["x_size"],
            u_size=PARAMS["GAN"]["u_size"],
            z_size=PARAMS["GAN"]["z_size"],
            use_cuda=kwargs["use_cuda"],
        )

        if kwargs["continue"] is not None:
            epoch_start = agent.load(kwargs["continue"])
            logger = logging.Logger(path=histpath,
                                    max_len=kwargs["save_interval"],
                                    mode="r+")
        else:
            epoch_start = 0
            logger = logging.Logger(path=histpath,
                                    max_len=kwargs["save_interval"])

        t0 = time.time()
        for epoch in tqdm.trange(epoch_start,
                                 epoch_start + 1 + kwargs["max_epoch"]):
            dataloader = gan.get_dataloader(samplefiles,
                                            shuffle=True,
                                            batch_size=kwargs["batch_size"])

            loss_d = loss_g = 0
            for data in tqdm.tqdm(dataloader):
                agent.set_input(data)
                agent.train()
                loss_d += agent.loss_d.mean().detach().cpu().numpy()
                loss_g += agent.loss_g.mean().detach().cpu().numpy()

            logger.record(epoch=epoch, loss_d=loss_d, loss_g=loss_g)

            if (epoch % kwargs["save_interval"] == 0
                    or epoch == epoch_start + 1 + kwargs["max_epoch"]):
                savepath = os.path.join(gandir, f"trained-{epoch:05d}.pth")
                agent.save(epoch, savepath)
                tqdm.tqdm.write(f"Weights are saved in {savepath}.")

        print(f"Elapsed time: {time.time() - t0:5.2f} sec")

    if mode == "copdac" or mode == "all":
        np.random.seed(1)

        env = envs.BaseEnv(initial_perturb=[0, 0, 0, 0.2])

        copdacdir = kwargs["copdac_dir"]

        agentname = "COPDAC"
        Agent = getattr(agents, agentname)
        agent = Agent(
            env,
            lrw=PARAMS["COPDAC"]["lrw"],
            lrv=PARAMS["COPDAC"]["lrv"],
            lrtheta=PARAMS["COPDAC"]["lrtheta"],
            w_init=PARAMS["COPDAC"]["w_init"],
            v_init=PARAMS["COPDAC"]["v_init"],
            theta_init=PARAMS["COPDAC"]["lrv"],
            maxlen=PARAMS["COPDAC"]["maxlen"],
            batch_size=PARAMS["COPDAC"]["batch_size"],
        )

        expname = "-".join([type(n).__name__ for n in (env, agent)])
        if kwargs["with_gan"]:
            expname += "-gan"
            agent.set_gan(kwargs["with_gan"], PARAMS["COPDAC"]["lrg"])

        if kwargs["with_reg"]:
            expname += "-reg"
            agent.set_reg(PARAMS["COPDAC"]["lrc"])

        histpath = os.path.join(copdacdir, expname + ".h5")
        if kwargs["continue"] is not None:
            epoch_start, i = agent.load(kwargs["continue"])
            logger = logging.Logger(path=histpath, max_len=100, mode="r+")
        else:
            epoch_start, i = 0, 0
            logger = logging.Logger(path=histpath, max_len=100)

        print(f"Training {agentname}...")

        epoch_end = epoch_start + kwargs["max_epoch"]
        for epoch in tqdm.trange(epoch_start, epoch_end):
            dataloader = gan.get_dataloader(samplefiles,
                                            keys=("state", "action", "reward",
                                                  "next_state"),
                                            shuffle=True,
                                            batch_size=64)

            for data in tqdm.tqdm(dataloader, desc=f"Epoch {epoch}"):
                agent.set_input(data)
                agent.train()

                if i % kwargs["save_interval"] == 0 or i == len(dataloader):
                    logger.record(epoch=epoch,
                                  i=i,
                                  w=agent.w,
                                  v=agent.v,
                                  theta=agent.theta,
                                  loss=agent.get_losses())

                i += 1

        logger.close()
Example #13
def train(obj, samples, **kwargs):
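    # Train one GAN per sample (file or directory), keeping the checkpoints
    # and train_history.h5 under a per-sample subdirectory of gan_dir.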
    for sample in samples:
        basedir = os.path.relpath(sample, obj.sample_dir)

        if os.path.isdir(sample):
            samplefiles = sorted(glob.glob(os.path.join(sample, "*.h5")))
        elif os.path.isfile(sample):
            samplefiles = sample
            basedir = os.path.splitext(basedir)[0]
        else:
            raise ValueError("unknown sample type.")

        gandir = os.path.join(obj.gan_dir, basedir)

        if kwargs["continue"] is None and os.path.exists(gandir):
            shutil.rmtree(gandir)
        os.makedirs(gandir, exist_ok=True)

        torch.save(
            samplefiles,
            os.path.join(gandir, "sample_path.h5"),
        )

        print(f"Train GAN for sample ({sample}) ...")

        save_interval = int(kwargs["save_interval"])

        agent = gan.GAN(
            lr=kwargs["lr"],
            x_size=1,
            u_size=1,
            z_size=kwargs["z_size"],
            use_cuda=kwargs["use_cuda"],
        )

        prog = functools.partial(
            _gan_prog,
            agent=agent,
            files=samplefiles,
            batch_size=kwargs["batch_size"],
        )

        histpath = os.path.join(gandir, "train_history.h5")

        if kwargs["continue"] is not None:
            epoch_start = agent.load(kwargs["continue"])
            logger = logging.Logger(path=histpath,
                                    max_len=kwargs["save_interval"],
                                    mode="r+")
        else:
            epoch_start = 0
            logger = logging.Logger(path=histpath,
                                    max_len=kwargs["save_interval"])

        t0 = time.time()
        for epoch in tqdm.trange(epoch_start,
                                 epoch_start + 1 + kwargs["max_epoch"]):
            loss_d, loss_g = prog(epoch)

            logger.record(epoch=epoch, loss_d=loss_d, loss_g=loss_g)

            if (epoch % save_interval == 0
                    or epoch == epoch_start + kwargs["max_epoch"]):
                savepath = os.path.join(gandir, f"trained_{epoch:05d}.pth")
                agent.save(epoch, savepath)
                tqdm.tqdm.write(f"Weights are saved in {savepath}.")

        logger.close()

        print(f"Elapsed time: {time.time() - t0:5.2f} sec")