Example 1: QT-Opt update (on-policy episodes merged into an off-policy Traj, then qtopt.train)

        # on-policy sampling
        print('on-policy')
        on_traj = Traj(traj_device='cpu')
        on_traj.add_epis(epis)
        # next observations are needed for the Q-learning targets
        on_traj = epi_functional.add_next_obs(on_traj)
        on_traj.register_epis()
        off_traj.add_traj(on_traj)  # add to the off-policy buffer

        # count episodes and steps
        total_epi += on_traj.num_epi
        step = on_traj.num_step
        total_step += step
        epoch = step  # train for as many epochs as steps collected this iteration

        if args.data_parallel:
            qf.dp_run = True
            lagged_qf.dp_run = True
            targ_qf1.dp_run = True
            targ_qf2.dp_run = True
        # train
        print('train')
        result_dict = qtopt.train(off_traj,
                                  qf,
                                  lagged_qf,
                                  targ_qf1,
                                  targ_qf2,
                                  optim_qf,
                                  epoch,
                                  args.batch_size,
                                  args.tau,
                                  args.gamma,
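
The data-parallel flags switched on before the update are typically switched back off once it returns; Example 5 shows this clean-up explicitly for its DDPG networks. A minimal sketch of the same pattern for the QT-Opt Q-functions, assuming it runs right after qtopt.train returns:

        if args.data_parallel:
            qf.dp_run = False
            lagged_qf.dp_run = False
            targ_qf1.dp_run = False
            targ_qf2.dp_run = False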
Example 2: GAIL update (discriminator pseudo-rewards, then gail.train)

        epis = sampler.sample(pol, max_steps=args.max_steps_per_iter)
    with measure('train'):
        agent_traj = Traj()
        agent_traj.add_epis(epis)
        # pseudo-rewards come from the discriminator instead of the environment reward
        agent_traj = ef.compute_pseudo_rews(agent_traj, discrim)
        agent_traj = ef.compute_vs(agent_traj, vf)
        agent_traj = ef.compute_rets(agent_traj, args.gamma)
        agent_traj = ef.compute_advs(agent_traj, args.gamma, args.lam)
        agent_traj = ef.centerize_advs(agent_traj)
        agent_traj = ef.compute_h_masks(agent_traj)
        agent_traj.register_epis()

        if args.data_parallel:
            pol.dp_run = True
            vf.dp_run = True
            discrim.dp_run = True

        if args.rl_type == 'trpo':
            result_dict = gail.train(
                agent_traj,
                expert_traj,
                pol,
                vf,
                discrim,
                optim_vf,
                optim_discrim,
                rl_type=args.rl_type,
                epoch=args.epoch_per_iter,
                batch_size=args.batch_size
                if not args.rnn else args.rnn_batch_size,
                discrim_batch_size=args.discrim_batch_size,
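
Note the batch-size argument: with a recurrent policy (args.rnn), args.rnn_batch_size is passed instead of args.batch_size, presumably because recurrent updates are batched over whole sequences rather than individual transitions.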
Example 3: PPO update (GAE preprocessing, then ppo_clip.train)

    with measure('train'):
        on_traj = Traj()
        on_traj.add_epis(epis)

        # preprocessing: next observations, value estimates, returns and GAE advantages
        on_traj = ef.add_next_obs(on_traj)
        on_traj = ef.compute_vs(on_traj, vf)
        on_traj = ef.compute_rets(on_traj, args.gamma)
        on_traj = ef.compute_advs(on_traj, args.gamma, args.lam)
        on_traj = ef.centerize_advs(on_traj)
        on_traj = ef.compute_h_masks(on_traj)
        on_traj.register_epis()

        if args.data_parallel:
            pol.dp_run = True
            vf.dp_run = True
            qf.dp_run = True

        result_dict1 = ppo_clip.train(traj=on_traj,
                                      pol=pol,
                                      vf=vf,
                                      clip_param=args.clip_param,
                                      optim_pol=optim_pol,
                                      optim_vf=optim_vf,
                                      epoch=args.epoch_per_iter,
                                      batch_size=args.batch_size,
                                      max_grad_norm=args.max_grad_norm)

        total_epi += on_traj.num_epi
        step = on_traj.num_step
        total_step += step
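
Example 3 is the most complete excerpt: preprocessing, the clipped-PPO update and the episode/step bookkeeping all appear together. For orientation, here is a minimal sketch of the outer loop this kind of block sits in; the sampler call and the measure('train') block are taken from the excerpts, while the 'sample' label and the stopping condition on args.max_epis are assumptions:

    total_epi = 0
    total_step = 0
    while total_epi < args.max_epis:  # assumed stopping condition
        with measure('sample'):      # 'sample' label is an assumption
            epis = sampler.sample(pol, max_steps=args.max_steps_per_iter)
        with measure('train'):
            ...  # per-algorithm update block, as in the excerpts above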
Example 4: AIRL update (reward or advantage discriminator, then airl.train)

            state_only=(args.rew_type == 'rew'))
        agent_traj = ef.compute_vs(agent_traj, vf)
        agent_traj = ef.compute_rets(agent_traj, args.gamma)
        agent_traj = ef.compute_advs(agent_traj, args.gamma, args.lam)
        agent_traj = ef.centerize_advs(agent_traj)
        agent_traj = ef.compute_h_masks(agent_traj)
        agent_traj.register_epis()

        if args.data_parallel:
            pol.dp_run = True
            vf.dp_run = True
            if args.rew_type == 'rew':
                rewf.dp_run = True
                shaping_vf.dp_run = True
            elif args.rew_type == 'adv':
                advf.dp_run = True

        if args.rl_type == 'trpo':
            result_dict = airl.train(
                agent_traj,
                expert_traj,
                pol,
                vf,
                optim_vf,
                optim_discrim,
                rewf=rewf,
                shaping_vf=shaping_vf,
                advf=advf,
                rl_type=args.rl_type,
                epoch=args.epoch_per_iter,
                batch_size=args.batch_size,
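
The rew_type switch mirrors the two AIRL discriminator parameterizations visible above: 'rew' uses a state-only reward network (rewf) together with a shaping value function (shaping_vf), while 'adv' uses a single advantage network (advf); the data-parallel block toggles whichever set is active, and airl.train receives all three keyword arguments.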
Example 5: Prioritized DDPG update (max-priority insertion into the replay Traj, then prioritized_ddpg.train)

        on_traj = ef.add_next_obs(on_traj)
        max_pri = on_traj.get_max_pri()
        on_traj = ef.set_all_pris(on_traj, max_pri)
        on_traj.register_epis()

        off_traj.add_traj(on_traj)

        total_epi += on_traj.num_epi
        step = on_traj.num_step
        total_step += step

        if args.data_parallel:
            pol.dp_run = True
            targ_pol.dp_run = True
            qf.dp_run = True
            targ_qf.dp_run = True

        result_dict = prioritized_ddpg.train(
            off_traj,
            pol, targ_pol, qf, targ_qf,
            optim_pol, optim_qf, step, args.batch_size,
            args.tau, args.gamma
        )

        if args.data_parallel:
            pol.dp_run = False
            targ_pol.dp_run = False
            qf.dp_run = False
            targ_qf.dp_run = False
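
Before the merge into off_traj, every newly collected transition is assigned the same maximum priority (get_max_pri / set_all_pris). This is the usual prioritized-experience-replay convention: fresh samples are guaranteed to be drawn at least once before their priorities are refined from TD errors.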