Пример #1
0
            # Clear the dp_run flag on both target Q-functions.
            targ_qf1.dp_run = False
            targ_qf2.dp_run = False

    # Accumulate the gradient-step counter by `epoch` for this iteration.
    total_grad_step += epoch
    # Refresh the lagged Q-network each time the counter crosses the next
    # multiple of `args.lag` (original note: every 6000 steps).
    if total_grad_step >= args.lag * num_update_lagged:
        logger.log('Updated lagged qf!!')
        lagged_qf_net.load_state_dict(qf_net.state_dict())
        num_update_lagged += 1

    # Per-episode return is the sum of that episode's rewards.
    rewards = [np.sum(epi['rews']) for epi in epis]
    mean_rew = np.mean(rewards)
    # Save the log.
    logger.record_results(args.log,
                          result_dict,
                          score_file,
                          total_epi,
                          step,
                          total_step,
                          rewards,
                          plot_title=args.env_name)

    if mean_rew > max_rew:  # save whenever the best mean reward is beaten
        # policy
        torch.save(pol.state_dict(),
                   os.path.join(args.log, 'models', 'pol_max.pkl'))
        # Q-function
        torch.save(qf.state_dict(),
                   os.path.join(args.log, 'models', 'qf_max.pkl'))
        # target Q theta1
        torch.save(targ_qf1.state_dict(),
                   os.path.join(args.log, 'models', 'targ_qf1_max.pkl'))
        # target Q theta 2
Пример #2
0
            traj=traj,
            student_pol=s_pol,
            teacher_pol=t_pol,
            student_optim=optim_pol,
            epoch=args.epoch_per_iter,
            batchsize=args.batch_size)

    # Evaluate the freshly distilled student policy on new rollouts.
    logger.log('Testing Student-policy')
    with measure('sample'):
        epis_measure = student_sampler.sample(
            s_pol, max_epis=args.max_epis_per_iter)

    with measure('measure'):
        # Wrap the evaluation episodes in a Traj so that episode and step
        # counts can be read from it below.
        traj_measure = Traj()
        traj_measure.add_epis(epis_measure)
        traj_measure = ef.compute_h_masks(traj_measure)
        traj_measure.register_epis()

    # Update the running counters from the evaluation trajectory.
    total_epi += traj_measure.num_epi
    step = traj_measure.num_step
    total_step += step
    # Per-episode return is the sum of that episode's rewards.
    rewards = [np.sum(epi['rews']) for epi in epis_measure]
    mean_rew = np.mean(rewards)
    # The sixth positional argument is the cumulative step count; the
    # previous code passed `total_epi` a second time here, so the logged
    # total-step value was actually the episode count.
    logger.record_results(args.log, result_dict, score_file,
                          total_epi, step, total_step, rewards,
                          plot_title='Policy Distillation')

    # Drop the per-iteration trajectories before the next iteration.
    del traj
    del traj_measure
del sampler
Пример #3
0
def main(args):
    """Run distributed training and checkpoint the policy/value function.

    Initialises Ray, prepares the log directory ``args.log`` (including its
    ``models`` sub-directory, ``progress.csv`` and ``args.json``), builds the
    environment, a policy matching the action space and a state-value
    function, then alternates sampling and training until at least
    ``args.max_epis`` episodes have been processed.  After every iteration
    the current policy, value function and both optimizer states are saved
    as ``*_last.pkl``; they are additionally saved as ``*_max.pkl`` whenever
    the mean episode reward beats the best value seen so far.

    Args:
        args: Parsed command-line namespace.  The attributes read here are
            ``num_cpus``, ``num_gpus``, ``ray_redis_address``, ``log``,
            ``env_name``, ``seed``, ``c2d``, ``num_trainer``,
            ``master_address``, ``num_parallel``, ``max_epis`` and
            ``max_steps_per_iter``; the whole namespace is also forwarded
            to the trainer.

    Raises:
        ValueError: If the action space is not Box, Discrete or
            MultiDiscrete.
    """
    init_ray(args.num_cpus, args.num_gpus, args.ray_redis_address)

    # Creating the nested directory also creates `args.log` itself, and
    # `exist_ok=True` removes the check-then-create race of the previous
    # `os.path.exists` + `mkdir` sequence.
    models_dir = os.path.join(args.log, 'models')
    os.makedirs(models_dir, exist_ok=True)
    score_file = os.path.join(args.log, 'progress.csv')
    logger.add_tabular_output(score_file)
    logger.add_tensorboard_output(args.log)
    # Keep a record of the run configuration next to the results.
    with open(os.path.join(args.log, 'args.json'), 'w') as f:
        json.dump(vars(args), f)
    pprint(vars(args))

    # when doing the distributed training, disable video recordings
    env = GymEnv(args.env_name)
    env.env.seed(args.seed)
    if args.c2d:
        env = C2DEnv(env)

    observation_space = env.observation_space
    action_space = env.action_space
    pol_net = PolNet(observation_space, action_space)
    rnn = False
    # pol_net = PolNetLSTM(observation_space, action_space)
    # rnn = True
    # Pick the policy class that matches the action space type.
    if isinstance(action_space, gym.spaces.Box):
        pol = GaussianPol(observation_space, action_space, pol_net, rnn=rnn)
    elif isinstance(action_space, gym.spaces.Discrete):
        pol = CategoricalPol(observation_space, action_space, pol_net)
    elif isinstance(action_space, gym.spaces.MultiDiscrete):
        pol = MultiCategoricalPol(observation_space, action_space, pol_net)
    else:
        raise ValueError('Only Box, Discrete, and MultiDiscrete are supported')

    vf_net = VNet(observation_space)
    vf = DeterministicSVfunc(observation_space, vf_net)

    trainer = TrainManager(Trainer,
                           args.num_trainer,
                           args.master_address,
                           args=args,
                           vf=vf,
                           pol=pol)
    sampler = EpiSampler(env, pol, args.num_parallel, seed=args.seed)

    total_epi = 0
    total_step = 0
    max_rew = -1e6
    start_time = time.time()

    # States fetched from the trainer each iteration; the key doubles as the
    # checkpoint file-name prefix.
    state_keys = ('pol', 'vf', 'optim_pol', 'optim_vf')

    while args.max_epis > total_epi:

        with measure('sample'):
            # Sample with the trainer's current policy parameters.
            sampler.set_pol_state(trainer.get_state("pol"))
            epis = sampler.sample(max_steps=args.max_steps_per_iter)

        with measure('train'):
            result_dict = trainer.train(epis=epis)

        step = result_dict["traj_num_step"]
        total_step += step
        total_epi += result_dict["traj_num_epi"]

        # Per-episode return is the sum of that episode's rewards.
        rewards = [np.sum(epi['rews']) for epi in epis]
        mean_rew = np.mean(rewards)
        elapsed_time = time.time() - start_time
        logger.record_tabular('ElapsedTime', elapsed_time)
        logger.record_results(args.log,
                              result_dict,
                              score_file,
                              total_epi,
                              step,
                              total_step,
                              rewards,
                              plot_title=args.env_name)

        with measure('save'):
            states = {key: trainer.get_state(key) for key in state_keys}

            # Always keep the most recent checkpoint.
            for key, state in states.items():
                torch.save(state,
                           os.path.join(models_dir, key + '_last.pkl'))

            # Additionally keep the best checkpoint seen so far.
            if mean_rew > max_rew:
                for key, state in states.items():
                    torch.save(state,
                               os.path.join(models_dir, key + '_max.pkl'))
                max_rew = mean_rew
    del sampler
    del trainer
Пример #4
0
        # When data-parallel mode is enabled, clear the dp_run flag on the
        # policy and the value function.
        if args.data_parallel:
            pol.dp_run = False
            vf.dp_run = False

    # NOTE(review): only traj1 contributes to the episode/step counters,
    # although rewards from epis2 are reported below as well - confirm
    # this is intended.
    total_epi += traj1.num_epi
    step = traj1.num_step
    total_step += step
    # Per-episode return (sum of rewards) for each of the two episode sets.
    rewards1 = [np.sum(epi['rews']) for epi in epis1]
    rewards2 = [np.sum(epi['rews']) for epi in epis2]
    # The best-model criterion is the mean over both sets combined.
    mean_rew = np.mean(rewards1 + rewards2)
    logger.record_tabular_misc_stat('Reward1', rewards1)
    logger.record_tabular_misc_stat('Reward2', rewards2)
    logger.record_results(args.log,
                          result_dict,
                          score_file,
                          total_epi,
                          step,
                          total_step,
                          rewards1 + rewards2,
                          plot_title='humanoid')

    # Save policy, value function and both optimizer states whenever the
    # combined mean reward beats the best value seen so far.
    if mean_rew > max_rew:
        torch.save(pol.state_dict(),
                   os.path.join(args.log, 'models', 'pol_max.pkl'))
        torch.save(vf.state_dict(),
                   os.path.join(args.log, 'models', 'vf_max.pkl'))
        torch.save(optim_pol.state_dict(),
                   os.path.join(args.log, 'models', 'optim_pol_max.pkl'))
        torch.save(optim_vf.state_dict(),
                   os.path.join(args.log, 'models', 'optim_vf_max.pkl'))
        max_rew = mean_rew
Пример #5
0
    def train(self):
        """Run the PPO-clip training loop until the frame budget is spent.

        Every iteration samples episodes with the current policy, builds a
        ``Traj`` from them (values, returns, centred advantages, h-masks),
        optionally appends its mirrored counterpart, runs
        ``ppo_clip.train``, records the results and saves checkpoints:
        "max" whenever the mean episode reward improves and "last" on every
        iteration.  The sampler is (re)built on the first iteration and each
        time the curriculum handler returns a truthy value.  Both learning
        rate schedulers are stepped once per iteration.
        """
        args = self.args

        # TODO: cuda seems to be broken, I don't care about it right now
        # (moving the rollouts to the GPU therefore stays disabled).

        self.train_start_time = time.time()

        progress_csv = os.path.join(self.logger.get_logdir(), "progress.csv")
        logger.add_tabular_output(progress_csv)

        # Running counters and the best mean reward seen so far.
        total_epi, total_step = 0, 0
        max_rew = -1e6
        sampler = None

        frame_budget = args.num_total_frames
        mirror_fn = None
        mirroring = args.mirror_tuples and hasattr(self.env.unwrapped,
                                                   "mirror_indices")
        if mirroring:
            mirror_fn = get_mirror_function(
                **self.env.unwrapped.mirror_indices)
            # The budget is doubled when mirroring - presumably because the
            # mirrored trajectory is appended to the step count below.
            frame_budget *= 2
            if not args.tanh_finish:
                warnings.warn(
                    "When `mirror_tuples` is `True`,"
                    " `tanh_finish` should be set to `True` as well."
                    " Otherwise there is a chance of the training blowing up.")

        while total_step < frame_budget:
            # setup the correct curriculum learning environment/parameters
            # NOTE(review): the ratio is taken against
            # `args.num_total_frames`, not the (possibly doubled) local
            # budget, so with mirroring it can exceed 1 - confirm intended.
            progress = total_step / args.num_total_frames
            new_curriculum = self.curriculum_handler(progress)

            if total_step == 0 or new_curriculum:
                # Release the previous sampler before creating its
                # replacement.
                if sampler is not None:
                    del sampler
                sampler = EpiSampler(
                    self.env,
                    self.pol,
                    num_parallel=self.args.num_processes,
                    seed=self.args.seed + total_step,  # TODO: better fix?
                )

            with measure("sample"):
                epis = sampler.sample(
                    self.pol, max_steps=args.num_steps * args.num_processes)

            with measure("train"):
                with measure("epis"):
                    traj = Traj()
                    traj.add_epis(epis)

                    # Annotate the trajectory step by step; each helper
                    # returns the updated trajectory.
                    traj = ef.compute_vs(traj, self.vf)
                    traj = ef.compute_rets(traj, args.decay_gamma)
                    traj = ef.compute_advs(
                        traj, args.decay_gamma, args.gae_lambda)
                    traj = ef.centerize_advs(traj)
                    traj = ef.compute_h_masks(traj)
                    traj.register_epis()

                    if mirror_fn:
                        traj.add_traj(mirror_fn(traj))

                # Recurrent policies use their own batch size setting.
                if args.rnn:
                    minibatch = args.rnn_batch_size
                else:
                    minibatch = args.batch_size

                # (data-parallel toggling of pol/vf around the update is
                # currently disabled)
                result_dict = ppo_clip.train(
                    traj=traj,
                    pol=self.pol,
                    vf=self.vf,
                    clip_param=args.clip_eps,
                    optim_pol=self.optim_pol,
                    optim_vf=self.optim_vf,
                    epoch=args.epoch_per_iter,
                    batch_size=minibatch,
                    max_grad_norm=args.max_grad_norm,
                )

            # append the metrics to the `results_dict` (reported in the
            # progress.csv)
            result_dict.update(self.get_extra_metrics(epis))

            total_epi += traj.num_epi
            step = traj.num_step
            total_step += step

            # Per-episode return is the sum of that episode's rewards.
            rewards = [np.sum(episode["rews"]) for episode in epis]
            mean_rew = np.mean(rewards)
            logger.record_results(
                self.logger.get_logdir(),
                result_dict,
                progress_csv,
                total_epi,
                step,
                total_step,
                rewards,
                plot_title=args.env,
            )

            # Best-so-far checkpoint first, then the rolling "last" one.
            if mean_rew > max_rew:
                self.save_models("max")
                max_rew = mean_rew

            self.save_models("last")

            self.scheduler_pol.step()
            self.scheduler_vf.step()

            del traj
                                     clip_param=clip_param,
                                     optim_pol=optim_pol,
                                     optim_vf=optim_vf,
                                     epoch=epoch_per_iter,
                                     batch_size=batch_size,
                                     max_grad_norm=max_grad_norm)
    # Update the episode/step counters from the trajectory and record the
    # results of this iteration.
    total_epi += traj.num_epi
    step = traj.num_step
    total_step += step
    # Per-episode return is the sum of that episode's rewards.
    rewards = [np.sum(epi['rews']) for epi in epis]
    mean_rew = np.mean(rewards)
    logger.record_results(log_dir_name,
                          result_dict,
                          score_file,
                          total_epi,
                          step,
                          total_step,
                          rewards,
                          plot_title=env_name)
    # Save policy, value function and both optimizer states whenever the
    # mean reward beats the best value seen so far.
    if mean_rew > max_rew:
        torch.save(pol.state_dict(),
                   os.path.join(log_dir_name, 'models', 'pol_max.pkl'))
        torch.save(vf.state_dict(),
                   os.path.join(log_dir_name, 'models', 'vf_max.pkl'))
        torch.save(optim_pol.state_dict(),
                   os.path.join(log_dir_name, 'models', 'optim_pol_max.pkl'))
        torch.save(optim_vf.state_dict(),
                   os.path.join(log_dir_name, 'models', 'optim_vf_max.pkl'))
        max_rew = mean_rew

    torch.save(pol.state_dict(),