コード例 #1
0
ファイル: taylor.py プロジェクト: BCHoagland/Continuity
def train(algo,
          env_name,
          num_timesteps,
          lr,
          noise,
          batch_size,
          vis_iter,
          seed=0,
          log=False,
          taylor_coef=0.5,
          *,
          explore_steps=10000,
          buffer_size=1e6):
    """Train an agent built from ``algo`` on the environment ``env_name``.

    Seeds torch, ``random`` and NumPy, pre-fills a ``Storage`` with
    ``explore_steps`` exploratory transitions, then for ``num_timesteps``
    steps alternates one environment interaction with one ``update`` call,
    periodically reporting the cost of the last finished episode.

    Args:
        algo: Callable (presumably an algorithm class) invoked as
            ``algo(taylor_coef)``; the result must provide ``create_models``,
            ``interact`` and ``update``.
        env_name: Name passed to ``Env`` to build the environment.
        num_timesteps: Number of training steps; converted with ``int()``
            so values such as ``1e6`` are accepted.
        lr: Learning rate forwarded to ``create_models``.
        noise: Noise value forwarded to ``interact``.
        batch_size: Batch size forwarded to ``update``.
        vis_iter: Report progress once every ``vis_iter`` steps; must be >= 1.
        seed: Seed for torch, ``random``, NumPy and the environment.
        log: If true, report via ``wandb.log``; otherwise ``print``.
        taylor_coef: Value passed to the ``algo`` constructor.
        explore_steps: Number of exploratory transitions passed to
            ``explore`` before training (previously hard-coded to 10000).
        buffer_size: Argument passed to ``Storage`` — presumably its
            capacity; previously hard-coded to ``1e6``.

    Raises:
        ValueError: If ``vis_iter`` is less than 1.
    """
    # Validate before any expensive work: vis_iter == 0 would otherwise only
    # fail (ZeroDivisionError) after the exploration phase, and a negative
    # value would silently disable every progress report.
    if vis_iter < 1:
        raise ValueError(f'vis_iter must be >= 1, got {vis_iter}')

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # create env
    env = Env(env_name, seed=seed)

    # build the agent; a separate name keeps the `algo` factory argument
    # distinct from the instance it produces
    n_s = env.state_dim()
    n_a = env.action_dim()
    agent = algo(taylor_coef)
    agent.create_models(lr, n_s, n_a, env.action_space)

    # create storage and pre-fill it with exploratory transitions
    storage = Storage(buffer_size)
    explore(explore_steps, env, storage)

    # cost bookkeeping
    last_ep_cost = 0  # cost of the most recently finished episode (0 until one ends)
    ep_cost = 0       # running cost of the current episode

    s = env.reset()
    for step in range(int(num_timesteps)):
        # interact with env without tracking gradients
        with torch.no_grad():
            s, a, c, s2, done = agent.interact(s, env, noise)
        storage.store((s, a, c, s2, done))

        # c is assumed to be a one-element tensor (it supports .item())
        ep_cost += c.item()

        # one learning update per environment step
        agent.update(storage, batch_size)

        # transition to next state; close out the episode on termination
        if done:
            s = env.reset()
            last_ep_cost = ep_cost
            ep_cost = 0
        else:
            s = s2

        # report on the last step of every vis_iter-step window
        if step % vis_iter == vis_iter - 1:
            if log:
                wandb.log({'Average episodic cost': last_ep_cost}, step=step)
            else:
                print(f'Step: {step} | Cost: {last_ep_cost}')