forked from rll/deeprlhw2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
question5.py
76 lines (61 loc) · 2.54 KB
/
question5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from rl import ValueFunction, animate_rollout, pathlength
import mjcmdp
import ppo
import numpy as np
from tabulate import tabulate
import argparse
from prepare_h5_file import prepare_h5_file
from mujoco_policy import MujocoPolicy
from neural_value import NeuralValueFunction
class MujocoNeuralValueFunction(NeuralValueFunction):
    """Neural value function using hand-crafted per-timestep path features.

    Only ``_features`` is overridden here; everything else is inherited from
    ``NeuralValueFunction``.
    """

    def _features(self, path):
        """Return the feature matrix for one path.

        Columns, left to right:
          * the observations clipped to [-10, 10],
          * the element-wise squares of those clipped observations,
          * the timestep index divided by 100, its square and its cube,
          * a constant column of ones (bias).

        :param path: mapping with an ``"observations"`` entry. For the
            column-wise concatenation to succeed, the clipped observations
            must be 2-D with ``pathlength(path)`` rows.
        :returns: array with ``pathlength(path)`` rows.
        """
        clipped = np.clip(path["observations"], -10, 10)
        n_steps = pathlength(path)
        # Scaled time index as a column vector so it lines up with the
        # per-timestep observation rows.
        t = np.arange(n_steps).reshape(-1, 1) / 100.0
        columns = [
            clipped,
            np.square(clipped),
            t,
            t ** 2,
            t ** 3,
            np.ones((n_steps, 1)),
        ]
        return np.concatenate(columns, axis=1)
def main():
    """Train a MujocoPolicy on the Hopper MDP with PPO.

    Parses command-line options, seeds NumPy's global RNG, builds the policy
    and the neural value function, then iterates over ``ppo.run_ppo``. On
    every iteration it:
      * adds the policy's per-dimension action stdev to the stats,
      * prints the stats as a table,
      * appends each stat to the ``diagnostics`` returned by
        ``prepare_h5_file``,
      * optionally animates a rollout (``--plot``),
      * writes a policy snapshot into the HDF5 group
        ``snapshots/<iteration>``.
    """
    # Command line arguments
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--outfile")
    parser.add_argument("--metadata")
    parser.add_argument("--plot", type=int, default=0)
    parser.add_argument("--just_sim", action="store_true")
    # Parameters
    parser.add_argument("--n_iter", type=int, default=150)
    parser.add_argument("--gamma", type=float, default=.99)
    parser.add_argument("--lam", type=float, default=1.0)
    parser.add_argument("--timesteps_per_batch", type=int, default=50000)
    parser.add_argument("--penalty_coeff", type=float, default=0.5)
    parser.add_argument("--max_pathlength", type=int, default=1000)
    args = parser.parse_args()

    # Alternative environment previously used here: mjcmdp.CartpoleMDP()
    np.random.seed(args.seed)
    mdp = mjcmdp.HopperMDP()
    (_, (ctrl_dim,)) = mdp.action_spec()
    (_, (obs_dim,)) = mdp.observation_spec()
    policy = MujocoPolicy(obs_dim, ctrl_dim)

    # Saving to HDF5
    hdf, diagnostics = prepare_h5_file(args, {"policy": policy, "mdp": mdp})

    # The network's input width must equal the number of columns built by
    # MujocoNeuralValueFunction._features: the clipped observations and their
    # squares (2 * obs_dim), plus time, time**2, time**3 and a bias column (4).
    # Derive it from obs_dim rather than hard-coding a number, so it stays
    # correct if the MDP (and thus the observation size) changes.
    num_features = 2 * obs_dim + 4
    vf = MujocoNeuralValueFunction(num_features=num_features, num_hidden=40)

    ppo_iterations = ppo.run_ppo(
        mdp, policy,
        vf=vf,
        gamma=args.gamma,
        lam=args.lam,
        max_pathlength=args.max_pathlength,
        timesteps_per_batch=args.timesteps_per_batch,
        n_iter=args.n_iter,
        parallel=False,
        penalty_coeff=args.penalty_coeff)

    for (iteration, stats) in enumerate(ppo_iterations):
        # Record the policy's current per-dimension action stdev with the stats.
        for (i, s) in enumerate(policy.get_stdev()):
            stats["std_%i" % i] = s
        # A parenthesized single-argument print behaves identically under
        # Python 2 and is also valid Python 3 syntax.
        print(tabulate(list(stats.items())))
        for (statname, statval) in stats.items():
            diagnostics[statname].append(statval)
        if args.plot:
            animate_rollout(mdp, policy, delay=.001, horizon=args.max_pathlength)
        # One snapshot group per iteration; "%.4i" zero-pads to 4 digits.
        grp = hdf.create_group("snapshots/%.4i" % iteration)
        policy.pc.to_h5(grp)


if __name__ == "__main__":
    main()