def train(algo, env_name, num_timesteps, lr, noise, batch_size, vis_iter, seed=0, log=False, taylor_coef=0.5):
    """Train an off-policy agent on an environment and report episodic cost.

    Args:
        algo: Agent *class* (not instance); constructed as ``algo(taylor_coef)``
            and expected to expose ``create_models``, ``interact`` and ``update``.
        env_name: Name passed to the ``Env`` wrapper.
        num_timesteps: Total number of environment steps to train for.
        lr: Learning rate forwarded to ``algo.create_models``.
        noise: Exploration noise forwarded to ``agent.interact``.
        batch_size: Minibatch size for each ``agent.update`` call.
        vis_iter: Report progress every ``vis_iter`` steps.
        seed: Seed for torch / random / numpy and the environment.
        log: If True, log to wandb; otherwise print to stdout.
        taylor_coef: Coefficient forwarded to the agent constructor.
    """
    # Seed all RNG sources for reproducibility.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Create env and models.
    env = Env(env_name, seed=seed)

    # Set up the agent. NOTE: the original rebound the `algo` parameter to its
    # own instance; use a distinct name so class vs. instance stays clear.
    n_s = env.state_dim()
    n_a = env.action_dim()
    agent = algo(taylor_coef)
    agent.create_models(lr, n_s, n_a, env.action_space)

    # Create replay storage and pre-fill it with random transitions.
    # Capacity is a count, so pass an int (original passed the float 1e6).
    storage = Storage(int(1e6))
    explore(10000, env, storage)

    # Training loop.
    last_ep_cost = 0
    ep_cost = 0
    s = env.reset()
    for step in range(int(num_timesteps)):
        # Interact with the env; no gradients needed for action selection.
        with torch.no_grad():
            s, a, c, s2, done = agent.interact(s, env, noise)
            storage.store((s, a, c, s2, done))

        # Cost bookkeeping.
        ep_cost += c.item()

        # Gradient update (outside no_grad so autograd is active).
        agent.update(storage, batch_size)

        # Transition to next state + cost bookkeeping.
        if done:
            s = env.reset()
            last_ep_cost = ep_cost
            ep_cost = 0
        else:
            s = s2

        # Report progress.
        if step % vis_iter == vis_iter - 1:
            if log:
                wandb.log({'Average episodic cost': last_ep_cost}, step=step)
            else:
                print(f'Step: {step} | Cost: {last_ep_cost}')