コード例 #1
0
ファイル: test_algos.py プロジェクト: takerfume/machina
    def test_learning(self):
        """Run one DIAYN-SAC update and one discriminator update end to end.

        Builds a Gaussian policy, two Q-functions with target copies and a
        discriminator, samples episodes from ``self.env`` (``max_steps=200``),
        then calls ``diayn_sac.train`` followed by ``diayn.train``. Nothing is
        asserted about the returned values; the test passes when no call
        raises.
        """
        # The same constant is subtracted from the observation size and handed
        # to the DIAYN reward/training helpers -- presumably the number of
        # skills; confirm against the environment wrapper.
        num_skill = 4

        env = self.env
        ob_space = env.real_observation_space  # read here, not used below
        skill_space = env.skill_space
        ob_skill_space = env.observation_space
        ac_space = env.action_space
        f_dim = ob_skill_space.shape[0] - num_skill

        def discrim_f(x):
            # Identity feature map for the discriminator input.
            return x

        pol_net = PolNet(ob_skill_space, ac_space)
        pol = GaussianPol(ob_skill_space, ac_space, pol_net)

        # Two Q-functions, each paired with a target network that starts from
        # the same weights.
        qf_nets, qfs, targ_qfs = [], [], []
        for _ in range(2):
            qf_net = QNet(ob_skill_space, ac_space)
            qf = DeterministicSAVfunc(ob_skill_space, ac_space, qf_net)
            targ_qf_net = QNet(ob_skill_space, ac_space)
            targ_qf_net.load_state_dict(qf_net.state_dict())
            targ_qf = DeterministicSAVfunc(
                ob_skill_space, ac_space, targ_qf_net)
            qf_nets.append(qf_net)
            qfs.append(qf)
            targ_qfs.append(targ_qf)

        # This parameter is the one registered with optim_alpha below.
        optimized_log_alpha = nn.Parameter(torch.ones(()))

        # Discriminator over an f_dim-dimensional box bounded by the largest
        # finite float32 value.
        float32_max = np.finfo(np.float32).max
        high = np.array([float32_max] * f_dim)
        f_space = gym.spaces.Box(-high, high, dtype=np.float32)
        discrim_net = DiaynDiscrimNet(
            f_space, skill_space, h_size=100, discrim_f=discrim_f)
        discrim = DeterministicSVfunc(f_space, discrim_net)

        optim_pol = torch.optim.Adam(pol_net.parameters(), lr=1e-4)
        optim_qfs = [
            torch.optim.Adam(net.parameters(), lr=3e-4) for net in qf_nets
        ]
        optim_alpha = torch.optim.Adam([optimized_log_alpha], lr=1e-4)
        optim_discrim = torch.optim.SGD(
            discrim.parameters(), lr=0.001, momentum=0.9)

        off_traj = Traj()
        sampler = EpiSampler(env, pol, num_parallel=1)
        epis = sampler.sample(pol, max_steps=200)

        def rew_func(x):
            return diayn_sac.calc_rewards(x, num_skill, discrim)

        on_traj = Traj()
        on_traj.add_epis(epis)
        on_traj = ef.add_next_obs(on_traj)
        on_traj = ef.compute_diayn_rews(on_traj, rew_func)
        on_traj.register_epis()
        off_traj.add_traj(on_traj)
        step = on_traj.num_step

        # "fix alpha": training receives a fresh parameter that optim_alpha
        # does not hold, so the optimizer never steps this tensor.
        log_alpha = nn.Parameter(np.log(0.1) * torch.ones(()))

        result_dict = diayn_sac.train(
            off_traj, pol, qfs, targ_qfs, log_alpha,
            optim_pol, optim_qfs, optim_alpha,
            step, 128, 5e-3, 0.99, 1, discrim, num_skill, True)
        discrim_losses = diayn.train(
            discrim, optim_discrim, on_traj, 32, 100, num_skill)

        del sampler
コード例 #2
0
    raise ValueError('Only Box, Discrete, and MultiDiscrete are supported')

# Value-function network: the LSTM variant when args.rnn is truthy, the plain
# network otherwise. The same flag is forwarded to DeterministicSVfunc.
vf_net = (VNetLSTM(observation_space, h_size=256, cell_size=256)
          if args.rnn else VNet(observation_space))
vf = DeterministicSVfunc(observation_space, vf_net, args.rnn)

# Only the rank-0 process constructs an episode sampler.
if rank == 0:
    sampler = EpiSampler(
        env, pol, num_parallel=args.num_parallel, seed=args.seed)

optim_pol = torch.optim.Adam(pol.parameters(), lr=args.pol_lr)
optim_vf = torch.optim.Adam(vf.parameters(), lr=args.vf_lr)

# make_model_distributed returns a (model, optimizer) pair; optim_pol is
# rebound to the returned optimizer.
ddp_pol, optim_pol = make_model_distributed(
    pol,
    optim_pol,
    args.use_apex,
    args.apex_opt_level,
    args.apex_keep_batchnorm_fp32,
    args.apex_sync_bn,
    args.apex_loss_scale,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
)
ddp_vf, optim_vf = make_model_distributed(vf,
                                          optim_vf,
                                          args.use_apex,
                                          args.apex_opt_level,
                                          args.apex_keep_batchnorm_fp32,
コード例 #3
0
ファイル: run_diayn.py プロジェクト: yumion/machina
# Discriminator input space: an f_dim-dimensional Box bounded by +/- the
# largest finite float32 value.
high = np.array([np.finfo(np.float32).max] * f_dim)
f_space = gym.spaces.Box(low=-high, high=high, dtype=np.float32)

discrim_net = DiaynDiscrimNet(
    f_space,
    skill_space,
    h_size=args.discrim_h_size,
    discrim_f=discrim_f,
)
discrim_net = discrim_net.to(device)

discrim = DeterministicSVfunc(f_space,
                              discrim_net,
                              rnn=False,
                              data_parallel=False,
                              parallel_dim=0)


# One optimizer per trainable component. The alpha optimizer uses the policy
# learning rate; the discriminator uses SGD with momentum.
optim_pol = torch.optim.Adam(pol_net.parameters(), lr=args.pol_lr)
optim_qf1 = torch.optim.Adam(qf_net1.parameters(), lr=args.qf_lr)
optim_qf2 = torch.optim.Adam(qf_net2.parameters(), lr=args.qf_lr)
optim_qfs = [optim_qf1, optim_qf2]
optim_alpha = torch.optim.Adam([log_alpha], lr=args.pol_lr)
optim_discrim = torch.optim.SGD(discrim.parameters(),
                                lr=args.discrim_lr,
                                momentum=args.discrim_momentum)

off_traj = Traj()
sampler = EpiSampler(env,
                     pol,
                     num_parallel=args.num_parallel,
                     seed=args.seed)

# Ensure both the log directory and its 'models' subdirectory exist.
# makedirs(..., exist_ok=True) creates missing parents, tolerates directories
# that are already there, and also covers the case where args.log exists but
# 'models' does not (a plain exists()/mkdir() pair would skip it).
os.makedirs(os.path.join(args.log, 'models'), exist_ok=True)

# Tabular progress is written to <args.log>/progress.csv.
score_file = os.path.join(args.log, 'progress.csv')
logger.add_tabular_output(score_file)

# Seed NumPy and PyTorch from the same command-line seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# counter and record for loop
total_epi = 0
コード例 #4
0
ファイル: run_diayn.py プロジェクト: takerfume/machina
                              h_size=args.discrim_h_size,
                              discrim_f=discrim_f).to(device)

# Wrap the discriminator network as a state-value function (no RNN, no data
# parallelism).
discrim = DeterministicSVfunc(
    f_space, discrim_net, rnn=False, data_parallel=False, parallel_dim=0)

# One optimizer per trainable component. The alpha optimizer uses the policy
# learning rate; the discriminator uses SGD with momentum.
optim_pol = torch.optim.Adam(pol_net.parameters(), lr=args.pol_lr)
optim_qf1 = torch.optim.Adam(qf_net1.parameters(), lr=args.qf_lr)
optim_qf2 = torch.optim.Adam(qf_net2.parameters(), lr=args.qf_lr)
optim_qfs = [optim_qf1, optim_qf2]
optim_alpha = torch.optim.Adam([log_alpha], lr=args.pol_lr)
optim_discrim = torch.optim.SGD(
    discrim.parameters(),
    lr=args.discrim_lr,
    momentum=args.discrim_momentum,
)

off_traj = Traj()
sampler = EpiSampler(
    env,
    pol,
    num_parallel=args.num_parallel,
    seed=args.seed,
)

# Ensure both the log directory and its 'models' subdirectory exist.
# makedirs(..., exist_ok=True) creates missing parents, tolerates directories
# that are already there, and also covers the case where args.log exists but
# 'models' does not (a plain exists()/mkdir() pair would skip it).
os.makedirs(os.path.join(args.log, 'models'), exist_ok=True)

# Tabular progress is written to <args.log>/progress.csv.
score_file = os.path.join(args.log, 'progress.csv')
logger.add_tabular_output(score_file)

# Seed NumPy and PyTorch from the same command-line seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# counter and record for loop