Пример #1
0
# Build policy, value function, sampler and optimizers for the training loop
# that follows. `env`, `args` and `observation_space` come from earlier in
# the script (not visible here).
action_space = env.action_space

# Policy network: the LSTM variant when args.rnn is set, plain PolNet otherwise.
if args.rnn:
    pol_net = PolNetLSTM(observation_space, action_space,
                         h_size=256, cell_size=256)
else:
    pol_net = PolNet(observation_space, action_space)

# Wrap the network in a policy object. parallel_dim is 1 in the rnn case and
# 0 otherwise (presumably the batch axis used for data-parallel splitting --
# confirm against GaussianPol).
pol = GaussianPol(observation_space, action_space, pol_net, args.rnn,
                    data_parallel=args.data_parallel, parallel_dim=1 if args.rnn else 0)

# Value-function network mirrors the rnn / non-rnn choice made for the policy.
if args.rnn:
    vf_net = VNetLSTM(observation_space, h_size=256, cell_size=256)
else:
    vf_net = VNet(observation_space)
vf = DeterministicSVfunc(observation_space, vf_net, args.rnn,
                         data_parallel=args.data_parallel, parallel_dim=1 if args.rnn else 0)

# Episode sampler bound to the environment and policy, seeded from args.seed.
sampler = EpiSampler(env, pol, num_parallel=args.num_parallel, seed=args.seed)

# Separate Adam optimizers over the raw networks, each with its own learning
# rate taken from the command-line arguments.
optim_pol = torch.optim.Adam(pol_net.parameters(), args.pol_lr)
optim_vf = torch.optim.Adam(vf_net.parameters(), args.vf_lr)

# Counters compared against args.max_epis by the loop below, plus a very low
# starting value for the reward tracker (presumably "best reward so far" --
# its use is outside this view).
total_epi = 0
total_step = 0
max_rew = -1e6
while args.max_epis > total_epi:
    with measure('sample'):
        epis = sampler.sample(pol, max_steps=args.max_steps_per_iter)
    with measure('train'):
        traj = Traj()
        traj.add_epis(epis)
Пример #2
0
    def test_learning(self):
        # Smoke test for one DIAYN round: build the policy, twin critics and
        # the discriminator, sample once from self.env, then call
        # diayn_sac.train and diayn.train a single time each. Nothing is
        # asserted; the test passes when every call returns without raising.

        # Spaces exposed by the environment under test.
        observation_space = self.env.real_observation_space  # fetched, unused
        skill_space = self.env.skill_space
        ob_skill_space = self.env.observation_space
        action_space = self.env.action_space
        # Feature size for the discriminator: combined observation length
        # minus 4 (presumably the skill entries -- confirm against the env).
        f_dim = ob_skill_space.shape[0] - 4

        def identity(x):
            # Feature function handed to the discriminator net: pass-through.
            return x

        pol_net = PolNet(ob_skill_space, action_space)
        pol = GaussianPol(ob_skill_space, action_space, pol_net)

        # Two critics; each target twin is created right after its critic and
        # starts from that critic's weights.
        qf_nets, qfs, targ_qfs = [], [], []
        for _ in range(2):
            qf_net = QNet(ob_skill_space, action_space)
            qf_nets.append(qf_net)
            qfs.append(
                DeterministicSAVfunc(ob_skill_space, action_space, qf_net))
            targ_qf_net = QNet(ob_skill_space, action_space)
            targ_qf_net.load_state_dict(qf_net.state_dict())
            targ_qfs.append(
                DeterministicSAVfunc(ob_skill_space, action_space,
                                     targ_qf_net))

        # Parameter the alpha optimizer is built on. A *different* parameter
        # is passed to the trainer further down, so this optimizer never
        # touches the value that is actually used.
        initial_log_alpha = nn.Parameter(torch.ones(()))

        # Feature space bounded by +/- the largest finite float32.
        f32_max = np.finfo(np.float32).max
        bound = np.array([f32_max] * f_dim)
        f_space = gym.spaces.Box(-bound, bound, dtype=np.float32)
        discrim_net = DiaynDiscrimNet(
            f_space, skill_space, h_size=100, discrim_f=identity)
        discrim = DeterministicSVfunc(f_space, discrim_net)

        optim_pol = torch.optim.Adam(pol_net.parameters(), 1e-4)
        optim_qfs = [
            torch.optim.Adam(net.parameters(), 3e-4) for net in qf_nets
        ]
        optim_alpha = torch.optim.Adam([initial_log_alpha], 1e-4)
        optim_discrim = torch.optim.SGD(
            discrim.parameters(), lr=0.001, momentum=0.9)

        replay_traj = Traj()
        epi_sampler = EpiSampler(self.env, pol, num_parallel=1)

        def skill_rewards(x):
            # Reward relabelling callback backed by the discriminator.
            return diayn_sac.calc_rewards(x, 4, discrim)

        # Collect fresh episodes and turn them into a registered trajectory.
        episodes = epi_sampler.sample(pol, max_steps=200)
        fresh_traj = Traj()
        fresh_traj.add_epis(episodes)
        fresh_traj = ef.add_next_obs(fresh_traj)
        fresh_traj = ef.compute_diayn_rews(fresh_traj, skill_rewards)
        fresh_traj.register_epis()
        replay_traj.add_traj(fresh_traj)
        step_count = fresh_traj.num_step

        # fix alpha: new parameter at log(0.1), unknown to optim_alpha.
        log_alpha = nn.Parameter(np.log(0.1) * torch.ones(()))
        sac_result = diayn_sac.train(
            replay_traj, pol, qfs, targ_qfs, log_alpha,
            optim_pol, optim_qfs, optim_alpha, step_count,
            128, 5e-3, 0.99, 1, discrim, 4, True)
        discrim_result = diayn.train(
            discrim, optim_discrim, fresh_traj, 32, 100, 4)

        # Drop the sampler explicitly (presumably so its resources are
        # released before the test returns -- confirm against EpiSampler).
        del epi_sampler
Пример #3
0
    def test_learning(self):
        # Smoke test for a single AIRL update with rl_type='trpo': build the
        # policy, value function and the two discriminator parts, load the
        # pickled expert episodes, sample once from self.env and call
        # airl.train once. Nothing is asserted; the test passes when every
        # call returns without raising.
        ob_space = self.env.observation_space
        ac_space = self.env.action_space

        pol_net = PolNet(ob_space, ac_space, h1=32, h2=32)
        pol = GaussianPol(ob_space, ac_space, pol_net)

        vf_net = VNet(ob_space)
        vf = DeterministicSVfunc(ob_space, vf_net)

        # Discriminator parts: a reward function and a shaping value function.
        rewf_net = VNet(ob_space, h1=32, h2=32)
        rewf = DeterministicSVfunc(ob_space, rewf_net)
        shaping_vf_net = VNet(ob_space, h1=32, h2=32)
        shaping_vf = DeterministicSVfunc(ob_space, shaping_vf_net)

        epi_sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_vf = torch.optim.Adam(vf_net.parameters(), 3e-4)
        # One optimizer covers both discriminator networks.
        discrim_params = [
            *rewf_net.parameters(),
            *shaping_vf_net.parameters(),
        ]
        optim_discrim = torch.optim.Adam(discrim_params, 3e-4)

        # Expert demonstrations; the path is relative to the working directory.
        expert_path = os.path.join('data/expert_epis', 'Pendulum-v0_2epis.pkl')
        with open(expert_path, 'rb') as expert_file:
            expert_epis = pickle.load(expert_file)
        expert_traj = Traj()
        expert_traj.add_epis(expert_epis)
        expert_traj = ef.add_next_obs(expert_traj)
        expert_traj.register_epis()

        # Agent rollouts, annotated step by step before training.
        agent_epis = epi_sampler.sample(pol, max_steps=32)
        agent_traj = Traj()
        agent_traj.add_epis(agent_epis)
        agent_traj = ef.add_next_obs(agent_traj)
        agent_traj = ef.compute_pseudo_rews(
            agent_traj, rew_giver=rewf, state_only=True)
        agent_traj = ef.compute_vs(agent_traj, vf)
        agent_traj = ef.compute_rets(agent_traj, 0.99)
        agent_traj = ef.compute_advs(agent_traj, 0.99, 0.95)
        agent_traj = ef.centerize_advs(agent_traj)
        agent_traj = ef.compute_h_masks(agent_traj)
        agent_traj.register_epis()

        train_options = dict(
            rewf=rewf,
            shaping_vf=shaping_vf,
            rl_type='trpo',
            epoch=1,
            batch_size=32,
            discrim_batch_size=32,
            discrim_step=1,
            pol_ent_beta=1e-3,
            gamma=0.99,
        )
        result_dict = airl.train(
            agent_traj, expert_traj, pol, vf, optim_vf, optim_discrim,
            **train_options)

        # Drop the sampler explicitly (presumably so its resources are
        # released before the test returns -- confirm against EpiSampler).
        del epi_sampler
Пример #4
0
                                parallel_dim=0)
# Group the two Q-functions and their target counterparts (all four are
# built earlier in the script, outside this view).
qfs = [qf1, qf2]
targ_qfs = [targ_qf1, targ_qf2]

# 0-dim learnable parameter initialised to 1.0 on `device`; it is the only
# parameter handed to optim_alpha below.
log_alpha = nn.Parameter(torch.ones((), device=device))

# Box space with f_dim dimensions, bounded by +/- the largest finite float32.
high = np.array([np.finfo(np.float32).max] * f_dim)
f_space = gym.spaces.Box(-high, high, dtype=np.float32)
# Discriminator network over f_space and skill_space, moved to `device`.
discrim_net = DiaynDiscrimNet(f_space,
                              skill_space,
                              h_size=args.discrim_h_size,
                              discrim_f=discrim_f).to(device)

# Wrap the discriminator network: non-recurrent, data parallelism disabled.
discrim = DeterministicSVfunc(f_space,
                              discrim_net,
                              rnn=False,
                              data_parallel=False,
                              parallel_dim=0)

# One optimizer per trainable component: Adam for the policy, each Q-network
# and log_alpha (which reuses the policy learning rate); SGD with momentum
# for the discriminator.
optim_pol = torch.optim.Adam(pol_net.parameters(), args.pol_lr)
optim_qf1 = torch.optim.Adam(qf_net1.parameters(), args.qf_lr)
optim_qf2 = torch.optim.Adam(qf_net2.parameters(), args.qf_lr)
optim_qfs = [optim_qf1, optim_qf2]
optim_alpha = torch.optim.Adam([log_alpha], args.pol_lr)
optim_discrim = torch.optim.SGD(discrim.parameters(),
                                lr=args.discrim_lr,
                                momentum=args.discrim_momentum)

# Empty trajectory container (presumably the off-policy buffer filled during
# training -- its use is outside this view) and the episode sampler.
off_traj = Traj()
sampler = EpiSampler(env, pol, num_parallel=args.num_parallel, seed=args.seed)
Пример #5
0
# Second Q-function and its target copy; the target network starts from the
# same weights as qf_net2. (qf1 / targ_qf1 are built above this view.)
qf_net2 = QNet(ob_skill_space, action_space)
qf2 = DeterministicSAVfunc(ob_skill_space, action_space, qf_net2)
targ_qf_net2 = QNet(ob_skill_space, action_space)
targ_qf_net2.load_state_dict(qf_net2.state_dict())
targ_qf2 = DeterministicSAVfunc(ob_skill_space, action_space, targ_qf_net2)
qfs = [qf1, qf2]
targ_qfs = [targ_qf1, targ_qf2]

# 0-dim learnable parameter initialised to 1.0 on `device`; it is the only
# parameter handed to optim_alpha below.
log_alpha = nn.Parameter(torch.ones((), device=device))

# Box space with f_dim dimensions, bounded by +/- the largest finite float32.
high = np.array([np.finfo(np.float32).max]*f_dim)
f_space = gym.spaces.Box(-high, high, dtype=np.float32)
# Discriminator network over f_space and skill_space, moved to `device`.
discrim_net = DiaynDiscrimNet(
    f_space, skill_space, h_size=args.discrim_h_size, discrim_f=discrim_f).to(device)

# Wrap the discriminator network as a non-recurrent function.
discrim = DeterministicSVfunc(f_space, discrim_net, rnn=False)


# One optimizer per trainable component: Adam for the policy, each Q-network
# and log_alpha (which reuses the policy learning rate); SGD with momentum
# for the discriminator.
optim_pol = torch.optim.Adam(pol_net.parameters(), args.pol_lr)
optim_qf1 = torch.optim.Adam(qf_net1.parameters(), args.qf_lr)
optim_qf2 = torch.optim.Adam(qf_net2.parameters(), args.qf_lr)
optim_qfs = [optim_qf1, optim_qf2]
optim_alpha = torch.optim.Adam([log_alpha], args.pol_lr)
optim_discrim = torch.optim.SGD(discrim.parameters(
), lr=args.discrim_lr, momentum=args.discrim_momentum)

# Empty trajectory container (presumably the off-policy buffer filled during
# training -- its use is outside this view) and the episode sampler.
off_traj = Traj()
sampler = EpiSampler(
    env, pol, num_parallel=args.num_parallel, seed=args.seed)
Пример #6
0
                              data_parallel=args.data_parallel,
                              parallel_dim=0)
else:
    raise ValueError('Only Box, Discrete, and MultiDiscrete are supported')

# Optionally restore policy weights from the path given in args.pol. The
# map_location callable returns the storage unchanged, which is the
# torch.load idiom for keeping every tensor on the CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
if args.pol:
    pol.load_state_dict(
        torch.load(args.pol, map_location=lambda storage, loc: storage))

# Value-function network configured entirely from command-line arguments.
vf_net = VNetCNP(ob_space,
                 h_size=args.h_size,
                 r_size=args.r_size,
                 aggregation=args.aggregation,
                 use_pe=args.use_pe)
vf = DeterministicSVfunc(ob_space,
                         vf_net,
                         data_parallel=args.data_parallel,
                         parallel_dim=0)

# Optionally restore value-function weights, loaded the same way as the policy.
if args.vf:
    vf.load_state_dict(
        torch.load(args.vf, map_location=lambda storage, loc: storage))

# Main sampler, plus a single-process sampler over center_env when one exists.
# center_sampler stays undefined when center_env is None.
sampler = EpiSampler(env, pol, num_parallel=args.num_parallel, seed=args.seed)
if center_env is not None:
    center_sampler = EpiSampler(center_env,
                                pol,
                                num_parallel=1,
                                seed=args.seed)

# Adam optimizer over the raw policy network parameters.
optim_pol = torch.optim.Adam(pol_net.parameters(), args.pol_lr)
if args.optim_pol:
Пример #7
0
elif isinstance(action_space, gym.spaces.Discrete):
    pol = CategoricalPol(observation_space,
                         action_space,
                         pol_net,
                         data_parallel=args.data_parallel)
elif isinstance(action_space, gym.spaces.MultiDiscrete):
    pol = MultiCategoricalPol(observation_space,
                              action_space,
                              pol_net,
                              data_parallel=args.data_parallel)
else:
    raise ValueError('Only Box, Discrete, and MultiDiscrete are supported')

# Value function: a VNet with default sizes wrapped in DeterministicSVfunc,
# with data parallelism controlled by the command-line flag.
vf_net = VNet(observation_space)
vf = DeterministicSVfunc(observation_space,
                         vf_net,
                         data_parallel=args.data_parallel)

if args.rew_type == 'rew':
    rewf_net = VNet(observation_space, h1=args.discrim_h1, h2=args.discrim_h2)
    rewf = DeterministicSVfunc(observation_space,
                               rewf_net,
                               data_parallel=args.data_parallel)
    shaping_vf_net = VNet(observation_space,
                          h1=args.discrim_h1,
                          h2=args.discrim_h2)
    shaping_vf = DeterministicSVfunc(observation_space,
                                     shaping_vf_net,
                                     data_parallel=args.data_parallel)
    optim_discrim = torch.optim.Adam(
        list(rewf_net.parameters()) + list(shaping_vf_net.parameters()),
    pol_net = PolNet(observation_space, action_space)
# Choose the policy class from the type of the action space; any space other
# than Box, Discrete or MultiDiscrete is rejected. pol_net is built above
# this view.
if isinstance(action_space, gym.spaces.Box):
    pol = GaussianPol(observation_space, action_space, pol_net, args.rnn)
elif isinstance(action_space, gym.spaces.Discrete):
    pol = CategoricalPol(observation_space, action_space, pol_net, args.rnn)
elif isinstance(action_space, gym.spaces.MultiDiscrete):
    pol = MultiCategoricalPol(observation_space, action_space, pol_net,
                              args.rnn)
else:
    raise ValueError('Only Box, Discrete, and MultiDiscrete are supported')

# Value-function network: the LSTM variant when args.rnn is set, plain VNet
# otherwise.
if args.rnn:
    vf_net = VNetLSTM(observation_space, h_size=256, cell_size=256)
else:
    vf_net = VNet(observation_space)
vf = DeterministicSVfunc(observation_space, vf_net, args.rnn)

# Only the process with distributed rank 0 creates the sampler; on every
# other rank the name `sampler` is left undefined by this section.
if dist.get_rank() == 0:
    sampler = DistributedEpiSampler(args.sampler_world_size,
                                    env=env,
                                    pol=pol,
                                    num_parallel=args.num_parallel,
                                    seed=args.seed)

# Separate Adam optimizers over the raw networks.
optim_pol = torch.optim.Adam(pol_net.parameters(), args.pol_lr)
optim_vf = torch.optim.Adam(vf_net.parameters(), args.vf_lr)

# Initial training state: counters, a very low starting value for the reward
# tracker, and the KL coefficient seeded from the command line (how these
# are updated is outside this view).
total_epi = 0
total_step = 0
max_rew = -1e6
kl_beta = args.init_kl_beta
Пример #9
0
# Policy network: the LSTM variant when args.rnn is set, plain PolNet otherwise.
if args.rnn:
    pol_net = PolNetLSTM(ob_space, ac_space, h_size=256, cell_size=256)
else:
    pol_net = PolNet(ob_space, ac_space)
# Choose the policy class from the type of the action space; any space other
# than Box, Discrete or MultiDiscrete is rejected.
if isinstance(ac_space, gym.spaces.Box):
    pol = GaussianPol(ob_space, ac_space, pol_net, args.rnn)
elif isinstance(ac_space, gym.spaces.Discrete):
    pol = CategoricalPol(ob_space, ac_space, pol_net, args.rnn)
elif isinstance(ac_space, gym.spaces.MultiDiscrete):
    pol = MultiCategoricalPol(ob_space, ac_space, pol_net, args.rnn)
else:
    raise ValueError('Only Box, Discrete, and MultiDiscrete are supported')

# NOTE(review): the value network is always the plain VNet, yet args.rnn is
# still forwarded to the wrapper (pol_net above switches to an LSTM variant
# when the flag is set). Confirm this combination works with --rnn.
vf_net = VNet(ob_space)
vf = DeterministicSVfunc(ob_space, vf_net, args.rnn)

# Discriminator over (observation, action) pairs with configurable layer sizes.
discrim_net = DiscrimNet(ob_space,
                         ac_space,
                         h1=args.discrim_h1,
                         h2=args.discrim_h2)
discrim = DeterministicSAVfunc(ob_space, ac_space, discrim_net)

# Episode sampler bound to the environment and policy, seeded from args.seed.
sampler = EpiSampler(env, pol, num_parallel=args.num_parallel, seed=args.seed)

# One Adam optimizer per network, each with its own learning rate.
optim_pol = torch.optim.Adam(pol_net.parameters(), args.pol_lr)
optim_vf = torch.optim.Adam(vf_net.parameters(), args.vf_lr)
optim_discrim = torch.optim.Adam(discrim_net.parameters(), args.discrim_lr)

# Load the pickled expert episodes from args.expert_dir / args.expert_fname.
# NOTE(review): pickle.load can execute arbitrary code from the file -- only
# point this at trusted expert data.
with open(os.path.join(args.expert_dir, args.expert_fname), 'rb') as f:
    expert_epis = pickle.load(f)
Пример #10
0
                     h_size=args.h_size,
                     cell_size=args.cell_size)

# Policy over a multi-categorical action space. The positional True is
# presumably the rnn flag (pol_net above is built with h_size / cell_size) --
# confirm against the constructor signature. parallel_dim is fixed at 1.
pol = MultiCategoricalPol(observation_space,
                          action_space,
                          pol_net,
                          True,
                          data_parallel=args.data_parallel,
                          parallel_dim=1)

# LSTM value-function network sized from the same command-line arguments,
# wrapped with the same positional True and parallel_dim as the policy.
vf_net = VNetLSTM(observation_space,
                  h_size=args.h_size,
                  cell_size=args.cell_size)
vf = DeterministicSVfunc(observation_space,
                         vf_net,
                         True,
                         data_parallel=args.data_parallel,
                         parallel_dim=1)

# Two samplers, one per environment, sharing the same policy object and the
# same seed.
sampler1 = EpiSampler(env1,
                      pol,
                      num_parallel=args.num_parallel,
                      seed=args.seed)
sampler2 = EpiSampler(env2,
                      pol,
                      num_parallel=args.num_parallel,
                      seed=args.seed)

# Separate Adam optimizers over the raw networks.
optim_pol = torch.optim.Adam(pol_net.parameters(), args.pol_lr)
optim_vf = torch.optim.Adam(vf_net.parameters(), args.vf_lr)