Example #1
File: play.py Project: Baichenjia/BHER

# Imports inferred from usage; the module paths follow the upstream
# baselines HER layout and may differ in the BHER fork.
import pickle

import numpy as np

from baselines import logger
from baselines.common import set_global_seeds
import baselines.her.experiment.config as config
from baselines.her.rollout import RolloutWorker
def main(policy_file, seed, n_test_rollouts, render):
    set_global_seeds(seed)

    # Load policy.
    with open(policy_file, 'rb') as f:
        policy = pickle.load(f)
    env_name = policy.info['env_name']

    # Prepare params.
    params = config.DEFAULT_PARAMS
    if env_name in config.DEFAULT_ENV_PARAMS:
        # merge env-specific parameters in
        params.update(config.DEFAULT_ENV_PARAMS[env_name])
    params['env_name'] = env_name
    params = config.prepare_params(params)
    config.log_params(params, logger=logger)

    dims = config.configure_dims(params)

    eval_params = {
        'exploit': True,
        'use_target_net': params['test_with_polyak'],
        'compute_Q': True,
        'rollout_batch_size': 1,
        'render': bool(render),
    }

    for name in ['T', 'gamma', 'noise_eps', 'random_eps']:
        eval_params[name] = params[name]

    evaluator = RolloutWorker(params['make_env'], policy, dims, logger,
                              **eval_params)
    evaluator.seed(seed)

    # Run evaluation.
    evaluator.clear_history()
    for _ in range(n_test_rollouts):
        evaluator.generate_rollouts()

    # record logs
    for key, val in evaluator.logs('test'):
        logger.record_tabular(key, np.mean(val))
    logger.dump_tabular()
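
In the upstream baselines HER layout, play.py wraps main in a click command line. A minimal sketch of such a wrapper, assuming baselines-style option names and defaults (not confirmed for the BHER fork):

# Hypothetical CLI wrapper in the style of baselines' play.py; the option
# names and defaults are assumptions, not confirmed for BHER.
import click

@click.command()
@click.argument('policy_file', type=str)
@click.option('--seed', type=int, default=0,
              help='RNG seed for the evaluation rollouts')
@click.option('--n_test_rollouts', type=int, default=10,
              help='number of evaluation rollouts to run')
@click.option('--render', type=int, default=1,
              help='whether to render the rollouts (0/1)')
def cli(policy_file, seed, n_test_rollouts, render):
    main(policy_file, seed, n_test_rollouts, render)

if __name__ == '__main__':
    cli()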
Example #2
File: train.py Project: Baichenjia/BHER

# Imports inferred from usage (shared by the train() variants in the
# examples below); the module paths follow the upstream baselines HER
# layout and may differ in the BHER fork.
import os
import datetime
import pickle

import numpy as np
from mpi4py import MPI

from baselines import logger
def train(policy, rollout_worker, evaluator, n_epochs, n_test_rollouts,
          n_cycles, n_batches, policy_save_interval, save_policies, **kwargs):
    rank = MPI.COMM_WORLD.Get_rank()

    latest_policy_path = os.path.join(logger.get_dir(), 'policy_latest.pkl')
    best_policy_path = os.path.join(logger.get_dir(), 'policy_best.pkl')
    periodic_policy_path = os.path.join(logger.get_dir(), 'policy_{}.pkl')

    logger.info("Training...")
    # logger.info("Epoch -1 | Finish Time :{}".format(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")))
    starttime = datetime.datetime.now()
    best_success_rate = -1
    for epoch in range(n_epochs):
        # train
        rollout_worker.clear_history()
        for _ in range(n_cycles):
            episode = rollout_worker.generate_rollouts()
            policy.store_episode(episode)
            for _ in range(n_batches):
                policy.train()
            policy.update_target_net()

        # test
        evaluator.clear_history()
        for _ in range(n_test_rollouts):
            evaluator.generate_rollouts()

        # logger.info("Epoch: {} | Finish Time :{}".format(epoch, datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")))

        # record logs
        logger.record_tabular('epoch', epoch)
        for key, val in evaluator.logs('test'):
            logger.record_tabular(key, mpi_average(val))
        for key, val in rollout_worker.logs('train'):
            logger.record_tabular(key, mpi_average(val))
        for key, val in policy.logs():
            logger.record_tabular(key, mpi_average(val))

        endtime = datetime.datetime.now()
        logger.record_tabular('time',
                              str(endtime - starttime).replace(',', '-'))

        if rank == 0:
            logger.dump_tabular()

        # save the policy if it's better than the previous ones
        success_rate = mpi_average(evaluator.current_success_rate())
        if rank == 0 and success_rate >= best_success_rate and save_policies:
            best_success_rate = success_rate
            logger.info(
                'New best success rate: {}. Saving policy to {} ...'.format(
                    best_success_rate, best_policy_path))
            evaluator.save_policy(best_policy_path)
            evaluator.save_policy(latest_policy_path)
        if (rank == 0 and policy_save_interval > 0 and
                epoch % policy_save_interval == 0 and save_policies):
            policy_path = periodic_policy_path.format(epoch)
            logger.info('Saving periodic policy to {} ...'.format(policy_path))
            evaluator.save_policy(policy_path)

        # make sure that different threads have different seeds
        local_uniform = np.random.uniform(size=(1, ))
        root_uniform = local_uniform.copy()
        MPI.COMM_WORLD.Bcast(root_uniform, root=0)
        if rank != 0:
            assert local_uniform[0] != root_uniform[0]
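
mpi_average is not shown in these excerpts. A minimal self-contained sketch of what such a helper does, i.e. averaging a scalar or list of values across all MPI ranks (the baselines implementation builds on mpi_moments instead; this version is an assumption):

# A minimal sketch of an mpi_average helper; an assumption about its
# behavior, not the BHER implementation.
from mpi4py import MPI
import numpy as np

def mpi_average(value):
    # Normalize scalars and empty lists so every rank contributes a
    # well-defined sum and count.
    if not isinstance(value, (list, np.ndarray)):
        value = [value]
    if len(value) == 0:
        value = [0.0]
    local = np.array([np.sum(value), len(value)], dtype=np.float64)
    total = np.zeros_like(local)
    # Sum the (sum, count) pairs over all ranks, then divide.
    MPI.COMM_WORLD.Allreduce(local, total, op=MPI.SUM)
    return total[0] / total[1]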
Example #3
File: train.py Project: Baichenjia/BHER
def train(policy, rollout_worker, evaluator, n_epochs, n_test_rollouts,
          n_cycles, n_batches, policy_save_interval, save_policies, **kwargs):
    rank = MPI.COMM_WORLD.Get_rank()

    latest_policy_path = os.path.join(logger.get_dir(), 'policy_latest.pkl')
    best_policy_path = os.path.join(logger.get_dir(), 'policy_best.pkl')
    periodic_policy_path = os.path.join(logger.get_dir(), 'policy_{}.pkl')
    r_mean_logdir = os.path.join(logger.get_dir(), 'total_rbias_mean.npy')
    r_std_logdir = os.path.join(logger.get_dir(), 'total_rbias_std.npy')

    logger.info("Training...")
    starttime = datetime.datetime.now()
    best_success_rate = -1
    for epoch in range(n_epochs):
        policy.epoch_num = epoch  # expose the current epoch to the policy
        # train
        rollout_worker.clear_history()
        for _ in range(n_cycles):
            episode = rollout_worker.generate_rollouts()
            policy.store_episode(episode)
            for _ in range(n_batches):
                policy.train()
            policy.update_target_net()
        # Configure per-epoch reward diagnostics: paths for the reward plot
        # (PDF) and the raw reward array (NPY) for this epoch.
        if rank == 0:
            policy.isPlot = False
            policy.picdir = os.path.join(logger.get_dir(),
                                         'rew_epoch_' + str(epoch) + '.pdf')
            policy.rewdir = os.path.join(logger.get_dir(),
                                         'rew_epoch_' + str(epoch) + '.npy')

        # test
        evaluator.clear_history()
        for _ in range(n_test_rollouts):
            evaluator.generate_rollouts()

        # record logs
        logger.record_tabular('epoch', epoch)
        for key, val in evaluator.logs('test'):
            logger.record_tabular(key, mpi_average(val))
        for key, val in rollout_worker.logs('train'):
            logger.record_tabular(key, mpi_average(val))
        for key, val in policy.logs():
            logger.record_tabular(key, mpi_average(val))

        endtime = datetime.datetime.now()
        logger.record_tabular('time',
                              str(endtime - starttime).replace(',', '-'))

        if rank == 0:
            logger.dump_tabular()

        # Save accumulated reward-bias statistics (written with pickle,
        # despite the .npy suffix).
        if rank == 0:
            with open(r_mean_logdir, "wb") as fp:
                pickle.dump(policy.total_epoch_r_mean_bias, fp)
            with open(r_std_logdir, "wb") as fp:
                pickle.dump(policy.total_epoch_r_std_bias, fp)

        # save the policy if it's better than the previous ones
        success_rate = mpi_average(evaluator.current_success_rate())
        if rank == 0 and success_rate >= best_success_rate and save_policies:
            best_success_rate = success_rate
            logger.info(
                'New best success rate: {}. Saving policy to {} ...'.format(
                    best_success_rate, best_policy_path))
            evaluator.save_policy(best_policy_path)
            evaluator.save_policy(latest_policy_path)
        if (rank == 0 and policy_save_interval > 0 and
                epoch % policy_save_interval == 0 and save_policies):
            policy_path = periodic_policy_path.format(epoch)
            logger.info('Saving periodic policy to {} ...'.format(policy_path))
            evaluator.save_policy(policy_path)

        # make sure that different threads have different seeds
        local_uniform = np.random.uniform(size=(1, ))
        root_uniform = local_uniform.copy()
        MPI.COMM_WORLD.Bcast(root_uniform, root=0)
        if rank != 0:
            assert local_uniform[0] != root_uniform[0]
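
Note that total_rbias_mean.npy and total_rbias_std.npy above are written with pickle.dump despite the .npy suffix, so they are not in NumPy's native format. Reading them back could look like this (the helper name is illustrative):

import os
import pickle

def load_rbias_stats(logdir):
    # The files were written with pickle.dump, so plain np.load without
    # allow_pickle=True would reject them; use pickle to read them back.
    with open(os.path.join(logdir, 'total_rbias_mean.npy'), 'rb') as fp:
        r_mean = pickle.load(fp)
    with open(os.path.join(logdir, 'total_rbias_std.npy'), 'rb') as fp:
        r_std = pickle.load(fp)
    return r_mean, r_std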
Example #4
def train(policy, rollout_worker, evaluator, n_epochs, n_test_rollouts,
          n_cycles, n_batches, policy_save_interval, save_policies, num_cpu,
          dump_buffer, rank_method, fit_interval, prioritization, **kwargs):
    rank = MPI.COMM_WORLD.Get_rank()

    latest_policy_path = os.path.join(logger.get_dir(), 'policy_latest.pkl')
    best_policy_path = os.path.join(logger.get_dir(), 'policy_best.pkl')
    periodic_policy_path = os.path.join(logger.get_dir(), 'policy_{}.pkl')

    logger.info("Training...")
    best_success_rate = -1
    t = 1
    starttime = datetime.datetime.now()
    for epoch in range(n_epochs):
        # train
        rollout_worker.clear_history()
        for cycle in range(n_cycles):
            episode = rollout_worker.generate_rollouts()
            # Refit the density model periodically (skipping the first
            # cycle) and always on the last cycle of the epoch.
            if ((cycle % fit_interval == 0 and cycle != 0)
                    or cycle == n_cycles - 1):
                if prioritization == 'entropy':
                    policy.fit_density_model()
            policy.store_episode(episode, dump_buffer, rank_method, epoch)
            for batch in range(n_batches):
                t = ((epoch * n_cycles * n_batches) +
                     (cycle * n_batches) + batch) * num_cpu
                policy.train(t, dump_buffer)

            policy.update_target_net()

        # test
        evaluator.clear_history()
        for _ in range(n_test_rollouts):
            evaluator.generate_rollouts()

        # record logs
        logger.record_tabular('epoch', epoch)
        for key, val in evaluator.logs('test'):
            logger.record_tabular(key, mpi_average(val))
        for key, val in rollout_worker.logs('train'):
            logger.record_tabular(key, mpi_average(val))
        for key, val in policy.logs():
            logger.record_tabular(key, mpi_average(val))
        endtime = datetime.datetime.now()
        logger.record_tabular('time',
                              str(endtime - starttime).replace(',', '-'))

        if rank == 0:
            logger.dump_tabular()

            if dump_buffer:
                policy.dump_buffer(epoch)

        # save the policy if it's better than the previous ones
        success_rate = mpi_average(evaluator.current_success_rate())
        if rank == 0 and success_rate >= best_success_rate and save_policies:
            best_success_rate = success_rate
            logger.info(
                'New best success rate: {}. Saving policy to {} ...'.format(
                    best_success_rate, best_policy_path))
            evaluator.save_policy(best_policy_path)
            evaluator.save_policy(latest_policy_path)
        if (rank == 0 and policy_save_interval > 0 and
                epoch % policy_save_interval == 0 and save_policies):
            policy_path = periodic_policy_path.format(epoch)
            logger.info('Saving periodic policy to {} ...'.format(policy_path))
            evaluator.save_policy(policy_path)

        # make sure that different threads have different seeds
        local_uniform = np.random.uniform(size=(1, ))
        root_uniform = local_uniform.copy()
        MPI.COMM_WORLD.Bcast(root_uniform, root=0)
        if rank != 0:
            assert local_uniform[0] != root_uniform[0]
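
The density-model refit condition in the cycle loop above compresses three cases into one expression. An equivalent, spelled-out form (the helper name is illustrative, not part of the source):

def should_fit_density(cycle, n_cycles, fit_interval):
    # Refit every fit_interval cycles, skipping cycle 0, and always
    # refit on the last cycle of the epoch.
    periodic = (cycle % fit_interval == 0) and (cycle != 0)
    last = (cycle == n_cycles - 1)
    return periodic or last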