Example no. 1
0
    def test_learning_rnn(self):
        """Smoke-test MPC training with a recurrent (LSTM) dynamics model."""
        def rew_func(next_obs,
                     acs,
                     mean_obs=0.,
                     std_obs=1.,
                     mean_acs=0.,
                     std_acs=1.):
            """Pendulum reward: negated cost of angle, velocity and action."""
            # Undo the normalization applied to observations and actions.
            denorm_obs = next_obs * std_obs + mean_obs
            denorm_acs = acs * std_acs + mean_acs
            # Classic Pendulum cost terms, with inputs clamped to valid ranges.
            angle_cost = torch.acos(denorm_obs[:, 0].clamp(min=-1, max=1))**2
            velocity_cost = 0.1 * (denorm_obs[:, 2].clamp(min=-8, max=8)**2)
            control_cost = 0.001 * denorm_acs.squeeze(-1)**2
            return (-(angle_cost + velocity_cost + control_cost)).squeeze(0)

        # Recurrent dynamics network and its deterministic model wrapper.
        dyn_net = ModelNetLSTM(self.env.observation_space,
                               self.env.action_space)
        dyn_model = DeterministicSModel(self.env.observation_space,
                                        self.env.action_space,
                                        dyn_net,
                                        rnn=True,
                                        data_parallel=False,
                                        parallel_dim=0)

        # MPC policy planning through the dynamics net (1 sample, horizon 1);
        # identity normalization stats since no data statistics exist yet.
        planner = MPCPol(self.env.observation_space,
                         self.env.action_space,
                         dyn_net,
                         rew_func,
                         1,
                         1,
                         mean_obs=0.,
                         std_obs=1.,
                         mean_acs=0.,
                         std_acs=1.,
                         rnn=True)
        dyn_optim = torch.optim.Adam(dyn_net.parameters(), 1e-3)

        # Collect a single episode with the MPC policy.
        sampler = EpiSampler(self.env, planner, num_parallel=1)
        epis = sampler.sample(planner, max_epis=1)

        traj = Traj()
        traj.add_epis(epis)
        traj = ef.add_next_obs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()
        # NOTE(review): appends the trajectory to itself, duplicating its
        # data — looks intentional for this smoke test, but confirm.
        traj.add_traj(traj)

        # Single training epoch of the dynamics model on the collected data.
        result_dict = mpc.train_dm(traj, dyn_model, dyn_optim, epoch=1, batch_size=1)

        del sampler
Example no. 2
0
# obs, next_obs, and acs should become mean 0, std 1; the returned statistics
# are reused below so the MPC policy can plan in the same normalized space.
traj, mean_obs, std_obs, mean_acs, std_acs = ef.normalize_obs_and_acs(traj)
traj.register_epis()

# The random-policy sampler was only needed for warm-up data collection.
del rand_sampler

### Train Dynamics Model ###

# initialize dynamics model and mpc policy
# Recurrent net when --rnn is set, plain feed-forward net otherwise.
if args.rnn:
    dm_net = ModelNetLSTM(ob_space, ac_space)
else:
    dm_net = ModelNet(ob_space, ac_space)
# Deterministic dynamics-model wrapper around the net.
# NOTE(review): parallel_dim=1 for the RNN case — presumably because the
# batch axis sits at dim 1 for time-major RNN inputs; confirm against
# DeterministicSModel's data_parallel handling.
dm = DeterministicSModel(ob_space,
                         ac_space,
                         dm_net,
                         args.rnn,
                         data_parallel=args.data_parallel,
                         parallel_dim=1 if args.rnn else 0)
# MPC policy that plans through dm_net with rew_func, passing the
# normalization statistics computed above.
mpc_pol = MPCPol(ob_space, ac_space, dm_net, rew_func, args.n_samples,
                 args.horizon_of_samples, mean_obs, std_obs, mean_acs, std_acs,
                 args.rnn)
optim_dm = torch.optim.Adam(dm_net.parameters(), args.dm_lr)

# Sampler that gathers episodes with the MPC policy for the loop below.
rl_sampler = EpiSampler(env,
                        mpc_pol,
                        num_parallel=args.num_parallel,
                        seed=args.seed)

# train loop
# Running counters accumulated across training iterations.
total_epi = 0
total_step = 0
Example no. 3
0
# Augment the warm-up trajectory with next-observation and hidden-state masks.
traj = ef.add_next_obs(traj)
traj = ef.compute_h_masks(traj)
# obs, next_obs, and acs should become mean 0, std 1; the returned statistics
# are reused below so the MPC policy can plan in the same normalized space.
traj, mean_obs, std_obs, mean_acs, std_acs = ef.normalize_obs_and_acs(traj)
traj.register_epis()

# The random-policy sampler was only needed for warm-up data collection.
del rand_sampler

### Train Dynamics Model ###

# initialize dynamics model and mpc policy
# Recurrent net when --rnn is set, plain feed-forward net otherwise.
if args.rnn:
    dm_net = ModelNetLSTM(observation_space, action_space)
else:
    dm_net = ModelNet(observation_space, action_space)
# Deterministic dynamics-model wrapper around the net.
dm = DeterministicSModel(observation_space, action_space, dm_net, args.rnn)
# MPC policy that plans through dm_net with rew_func, passing the
# normalization statistics computed above.
mpc_pol = MPCPol(observation_space, action_space, dm_net, rew_func,
                 args.n_samples, args.horizon_of_samples, mean_obs, std_obs,
                 mean_acs, std_acs, args.rnn)
optim_dm = torch.optim.Adam(dm_net.parameters(), args.dm_lr)

# Sampler that gathers episodes with the MPC policy for the loop below.
rl_sampler = EpiSampler(env,
                        mpc_pol,
                        num_parallel=args.num_parallel,
                        seed=args.seed)

# train loop
# Running counters accumulated across training iterations.
total_epi = 0
total_step = 0
counter_agg_iters = 0
max_rew = -1e+6  # best reward seen so far; -1e+6 acts as a "minus infinity" sentinel