コード例 #1
0
ファイル: test_algos.py プロジェクト: yumion/machina
    def test_learning(self):
        """Smoke-test a single DDPG update on a short sampled rollout.

        Builds an online and a target copy of both the deterministic policy
        and the Q-function, samples up to 32 steps with the noisy online
        policy, and runs ``ddpg.train`` once over the resulting trajectory.
        """
        ob_space = self.env.observation_space
        ac_space = self.env.action_space

        pol_net = PolNet(ob_space, ac_space, h1=32, h2=32, deterministic=True)
        noise = OUActionNoise(ac_space)
        pol = DeterministicActionNoisePol(ob_space, ac_space, pol_net, noise)

        # Target policy starts as an exact copy of the online policy network.
        targ_pol_net = PolNet(ob_space, ac_space, h1=32, h2=32,
                              deterministic=True)
        targ_pol_net.load_state_dict(pol_net.state_dict())
        targ_noise = OUActionNoise(ac_space)
        targ_pol = DeterministicActionNoisePol(ob_space, ac_space,
                                               targ_pol_net, targ_noise)

        qf_net = QNet(ob_space, ac_space, h1=32, h2=32)
        qf = DeterministicSAVfunc(ob_space, ac_space, qf_net)

        # Target Q-network starts as an exact copy of the *online* Q-network
        # (loading its own state_dict would be a no-op and leave the two
        # networks independently initialised).
        targ_qf_net = QNet(ob_space, ac_space, h1=32, h2=32)
        targ_qf_net.load_state_dict(qf_net.state_dict())
        targ_qf = DeterministicSAVfunc(ob_space, ac_space, targ_qf_net)

        sampler = EpiSampler(self.env, pol, num_parallel=1)

        optim_pol = torch.optim.Adam(pol_net.parameters(), 3e-4)
        optim_qf = torch.optim.Adam(qf_net.parameters(), 3e-4)

        epis = sampler.sample(pol, max_steps=32)

        traj = Traj()
        traj.add_epis(epis)
        # Attach next observations to the trajectory before registering it.
        traj = ef.add_next_obs(traj)
        traj.register_epis()

        # Trailing positional hyperparameters -- presumably
        # (epoch=1, batch_size=32, tau=0.01, gamma=0.9); confirm against
        # the ddpg.train signature.
        ddpg.train(traj, pol, targ_pol, qf, targ_qf, optim_pol,
                   optim_qf, 1, 32, 0.01, 0.9)

        del sampler
コード例 #2
0
ファイル: run_ddpg.py プロジェクト: iory/machina
# A negative --cuda index selects the CPU; otherwise use that CUDA device.
device_name = 'cpu' if args.cuda < 0 else "cuda:{}".format(args.cuda)
device = torch.device(device_name)
set_device(device)

score_file = os.path.join(args.log, 'progress.csv')
logger.add_tabular_output(score_file)

env = GymEnv(args.env_name, log_dir=os.path.join(
    args.log, 'movie'), record_video=args.record)
env.env.seed(args.seed)

ob_space = env.observation_space
ac_space = env.action_space

# Online actor: deterministic policy network wrapped with OUActionNoise.
pol_net = PolNet(ob_space, ac_space, args.h1, args.h2, deterministic=True)
noise = OUActionNoise(ac_space)
pol = DeterministicActionNoisePol(ob_space, ac_space, pol_net, noise)

# Target actor starts as an exact copy of the online actor. Its noise is
# constructed from the action space, exactly like the online noise above.
targ_pol_net = PolNet(ob_space, ac_space, args.h1, args.h2, deterministic=True)
targ_pol_net.load_state_dict(pol_net.state_dict())
targ_noise = OUActionNoise(ac_space)
targ_pol = DeterministicActionNoisePol(
    ob_space, ac_space, targ_pol_net, targ_noise)

qf_net = QNet(ob_space, ac_space, args.h1, args.h2)
qf = DeterministicSAVfunc(ob_space, ac_space, qf_net)

# Target critic starts as an exact copy of the online critic.
targ_qf_net = QNet(ob_space, ac_space, args.h1, args.h2)
targ_qf_net.load_state_dict(qf_net.state_dict())
targ_qf = DeterministicSAVfunc(ob_space, ac_space, targ_qf_net)
コード例 #3
0
             record_video=True,
             video_schedule=lambda x: True)
env.env.seed(args.seed)
if args.c2d:
    # Optionally wrap the environment in C2DEnv (presumably a
    # continuous-to-discrete action adapter -- confirm).
    env = C2DEnv(env)

observation_space = env.observation_space
action_space = env.action_space

if args.ddpg:
    # DDPG: deterministic policy network with OUActionNoise for exploration.
    pol_net = PolNet(observation_space,
                     action_space,
                     args.pol_h1,
                     args.pol_h2,
                     deterministic=True)
    noise = OUActionNoise(action_space)
    pol = DeterministicActionNoisePol(observation_space, action_space, pol_net,
                                      noise)
else:
    if args.rnn:
        pol_net = PolNetLSTM(observation_space,
                             action_space,
                             h_size=256,
                             cell_size=256)
    else:
        pol_net = PolNet(observation_space, action_space)
    # Pick the stochastic policy wrapper matching the action-space type.
    if isinstance(action_space, gym.spaces.Box):
        pol = GaussianPol(observation_space, action_space, pol_net, args.rnn)
    elif isinstance(action_space, gym.spaces.Discrete):
        pol = CategoricalPol(observation_space, action_space, pol_net,
                             args.rnn)
    else:
        # Fail fast instead of leaving `pol` undefined for later code.
        raise ValueError('Only Box and Discrete are supported')
コード例 #4
0
ファイル: make_expert_epis.py プロジェクト: takerfume/machina
             log_dir=os.path.join(args.pol_dir, 'movie'),
             record_video=args.record)
env.env.seed(args.seed)
if args.c2d:
    env = C2DEnv(env)

ob_space, ac_space = env.observation_space, env.action_space

if args.ddpg:
    # Deterministic policy network plus OUActionNoise exploration.
    pol_net = PolNet(ob_space, ac_space, args.pol_h1, args.pol_h2,
                     deterministic=True)
    # NOTE(review): confirm that this version of OUActionNoise expects an
    # action shape rather than the action space itself.
    noise = OUActionNoise(ac_space.shape)
    pol = DeterministicActionNoisePol(ob_space, ac_space, pol_net, noise)
else:
    # Recurrent or feed-forward network, chosen by --rnn.
    pol_net = (PolNetLSTM(ob_space, ac_space, h_size=256, cell_size=256)
               if args.rnn else PolNet(ob_space, ac_space))
    # Wrap the network in the distribution type matching the action space.
    if isinstance(ac_space, gym.spaces.Box):
        pol = GaussianPol(ob_space, ac_space, pol_net, args.rnn)
    elif isinstance(ac_space, gym.spaces.Discrete):
        pol = CategoricalPol(ob_space, ac_space, pol_net, args.rnn)
    elif isinstance(ac_space, gym.spaces.MultiDiscrete):
        pol = MultiCategoricalPol(ob_space, ac_space, pol_net, args.rnn)
    else:
        raise ValueError('Only Box, Discrete, and MultiDiscrete are supported')
コード例 #5
0
ファイル: run_ddpg.py プロジェクト: yumion/machina
device = torch.device(device_name)
set_device(device)

# Progress table is written as CSV into the log directory.
score_file = os.path.join(args.log, 'progress.csv')
logger.add_tabular_output(score_file)

env = GymEnv(args.env_name,
             log_dir=os.path.join(args.log, 'movie'),
             record_video=args.record)
env.env.seed(args.seed)

observation_space = env.observation_space
action_space = env.action_space

# Online actor: deterministic network wrapped with OUActionNoise.
pol_net = PolNet(observation_space, action_space,
                 h1=args.h1, h2=args.h2, deterministic=True)
noise = OUActionNoise(action_space)
pol = DeterministicActionNoisePol(observation_space, action_space,
                                  pol_net, noise)

# Target actor: same architecture, initialised from the online weights.
targ_pol_net = PolNet(observation_space, action_space,
                      h1=args.h1, h2=args.h2, deterministic=True)
targ_pol_net.load_state_dict(pol_net.state_dict())
targ_noise = OUActionNoise(action_space)
targ_pol = DeterministicActionNoisePol(observation_space, action_space,
                                       targ_pol_net, targ_noise)

# Online critic and its target copy.
qf_net = QNet(observation_space, action_space, h1=args.h1, h2=args.h2)
qf = DeterministicSAVfunc(observation_space, action_space, qf_net)

targ_qf_net = QNet(observation_space, action_space, h1=args.h1, h2=args.h2)
targ_qf_net.load_state_dict(qf_net.state_dict())