Example #1
def main(_):
    problem_config = FLAGS.problem_config

    # Load the offline dataset and environment.
    dataset, dev_dataset, environment = utils.load_data_and_env(
        task_name=problem_config['task_name'],
        noise_level=problem_config['noise_level'],
        near_policy_dataset=problem_config['near_policy_dataset'],
        dataset_path=FLAGS.dataset_path,
        batch_size=FLAGS.batch_size,
        max_dev_size=FLAGS.max_dev_size,
        shuffle=False,
        repeat=False)
    environment_spec = specs.make_environment_spec(environment)
    """
      task_gamma_map = {
          'bsuite_catch': 0.25,
          'bsuite_mountain_car': 0.5,
          'bsuite_cartpole': 0.44,
      }
      gamma = FLAGS.gamma or task_gamma_map[problem_config['task_name']]
  """

    gamma = utils.get_median(problem_config['task_name'], environment_spec,
                             dataset)

    # Create the networks to optimize.
    value_func, instrumental_feature = kiv_batch.make_ope_networks(
        problem_config['task_name'],
        environment_spec,
        n_component=FLAGS.n_component,
        gamma=gamma)

    # Load pretrained target policy network.
    target_policy_net = utils.load_policy_net(
        task_name=problem_config['task_name'],
        noise_level=problem_config['noise_level'],
        near_policy_dataset=problem_config['near_policy_dataset'],
        dataset_path=FLAGS.dataset_path,
        environment_spec=environment_spec)

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')
    logger = loggers.TerminalLogger('learner')

    # Count the batches in the offline dataset and split them between the two
    # KIV stages.
    num_batches = 0
    for _ in dataset:
        num_batches += 1
    stage1_batch = num_batches // 2
    stage2_batch = num_batches - stage1_batch

    # The learner updates the parameters (and initializes them).
    learner = kiv_batch.KIVLearner(value_func=value_func,
                                   instrumental_feature=instrumental_feature,
                                   policy_net=target_policy_net,
                                   discount=problem_config['discount'],
                                   stage1_reg=FLAGS.stage1_reg,
                                   stage2_reg=FLAGS.stage2_reg,
                                   stage1_batch=stage1_batch,
                                   stage2_batch=stage2_batch,
                                   dataset=dataset,
                                   valid_dataset=dev_dataset,
                                   counter=learner_counter,
                                   logger=logger,
                                   checkpoint=False)

    eval_counter = counting.Counter(counter, 'eval')
    eval_logger = loggers.TerminalLogger('eval')

    while True:
        results = {
            'gamma': gamma,
            'stage1_batch': stage1_batch,
            'stage2_batch': stage2_batch,
        }
        # Include learner results in eval results for ease of analysis.
        results.update(learner.step())
        results.update(
            utils.ope_evaluation(value_func=value_func,
                                 policy_net=target_policy_net,
                                 environment=environment,
                                 num_init_samples=FLAGS.evaluate_init_samples,
                                 discount=problem_config['discount'],
                                 counter=eval_counter))
        eval_logger.write(results)
        if learner.state['num_steps'] >= FLAGS.max_steps:
            break
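The examples in this file read their hyperparameters from absl FLAGS. Below is a minimal, hypothetical entry-point sketch for a script like Example #1; the flag names mirror the FLAGS references in the code above, but the types, defaults, and help strings are assumptions rather than the original definitions.

from absl import app
from absl import flags

FLAGS = flags.FLAGS

# Assumed definitions; defaults are placeholders, not the original values.
flags.DEFINE_string('dataset_path', None, 'Path to the offline dataset.')
flags.DEFINE_integer('batch_size', 1024, 'Batch size for the offline dataset.')
flags.DEFINE_integer('max_dev_size', 10000, 'Maximum dev-dataset size.')
flags.DEFINE_integer('n_component', 512, 'Number of components for the KIV features.')
flags.DEFINE_float('stage1_reg', 1e-5, 'Stage-1 ridge regularization.')
flags.DEFINE_float('stage2_reg', 1e-5, 'Stage-2 ridge regularization.')
flags.DEFINE_integer('max_steps', 1, 'Learner steps to run before stopping.')
flags.DEFINE_integer('evaluate_init_samples', 100,
                     'Number of initial states sampled during OPE evaluation.')
# problem_config is accessed as a dict with keys task_name, noise_level,
# near_policy_dataset and discount; in practice it would be supplied via a
# config-dict style flag (assumed here, not shown).

if __name__ == '__main__':
  app.run(main)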
Example #2
def main(_):
  problem_config = FLAGS.problem_config

  # Load the offline dataset and environment.
  dataset, dev_dataset, environment = utils.load_data_and_env(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      batch_size=FLAGS.batch_size,
      max_dev_size=FLAGS.max_dev_size)
  environment_spec = specs.make_environment_spec(environment)

  # Create the networks to optimize.
  value_func, instrumental_feature = dfiv.make_ope_networks(
      problem_config['task_name'], environment_spec,
      value_layer_sizes=FLAGS.value_layer_sizes,
      instrumental_layer_sizes=FLAGS.instrumental_layer_sizes)

  # Load pretrained target policy network.
  target_policy_net = utils.load_policy_net(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      environment_spec=environment_spec)

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')
  logger = loggers.TerminalLogger('learner')

  # The learner updates the parameters (and initializes them).
  learner_cls = dfiv.DFIVLearner
  if FLAGS.learner2:
    learner_cls = dfiv.DFIV2Learner
  learner = learner_cls(
      value_func=value_func,
      instrumental_feature=instrumental_feature,
      policy_net=target_policy_net,
      discount=problem_config['discount'],
      value_learning_rate=FLAGS.value_learning_rate,
      instrumental_learning_rate=FLAGS.instrumental_learning_rate,
      stage1_reg=FLAGS.stage1_reg,
      stage2_reg=FLAGS.stage2_reg,
      value_reg=FLAGS.value_reg,
      instrumental_reg=FLAGS.instrumental_reg,
      instrumental_iter=FLAGS.instrumental_iter,
      value_iter=FLAGS.value_iter,
      dataset=dataset,
      d_tm1_weight=FLAGS.d_tm1_weight,
      counter=learner_counter,
      logger=logger)

  eval_counter = counting.Counter(counter, 'eval')
  eval_logger = loggers.TerminalLogger('eval')

  while True:
    learner.step()
    steps = learner.state['num_steps'].numpy()

    if steps % FLAGS.evaluate_every == 0:
      eval_results = {}
      if dev_dataset is not None:
        eval_results = {'dev_mse': learner.cal_validation_err(dev_dataset)}
      eval_results.update(utils.ope_evaluation(
          value_func=value_func,
          policy_net=target_policy_net,
          environment=environment,
          num_init_samples=FLAGS.evaluate_init_samples,
          discount=problem_config['discount'],
          counter=eval_counter))
      eval_logger.write(eval_results)

    if steps >= FLAGS.max_steps:
      break
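Examples #2, #4 and #5 all share the same step/evaluate/stop loop. A small, hypothetical helper that captures that pattern, assuming a learner exposing .step() and a TensorFlow 'num_steps' counter in .state, and an evaluate_fn that returns a dict of metrics:

def run_learner_loop(learner, evaluate_fn, eval_logger, max_steps,
                     evaluate_every):
  """Steps the learner, writing evaluation results every `evaluate_every` steps."""
  while True:
    learner.step()
    steps = learner.state['num_steps'].numpy()
    if steps % evaluate_every == 0:
      eval_logger.write(evaluate_fn())
    if steps >= max_steps:
      break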
Example #3
def main(_):
    problem_config = FLAGS.problem_config

    # Load the offline dataset and environment.
    _, _, environment = utils.load_data_and_env(
        task_name=problem_config['task_name'],
        noise_level=problem_config['noise_level'],
        near_policy_dataset=problem_config['near_policy_dataset'],
        dataset_path=FLAGS.dataset_path,
        batch_size=1)
    environment_spec = specs.make_environment_spec(environment)

    # Load pretrained target policy network.
    policy_net = utils.load_policy_net(
        task_name=problem_config['task_name'],
        noise_level=problem_config['noise_level'],
        near_policy_dataset=problem_config['near_policy_dataset'],
        dataset_path=FLAGS.dataset_path,
        environment_spec=environment_spec)

    actor = actors.FeedForwardActor(policy_network=policy_net)

    logger = loggers.TerminalLogger('ground_truth')

    discount = problem_config['discount']

    returns = []
    lengths = []

    t_start = time.time()
    timestep = environment.reset()
    actor.observe_first(timestep)
    cur_return = 0.
    cur_step = 0
    while len(returns) < FLAGS.num_episodes:

        action = actor.select_action(timestep.observation)
        timestep = environment.step(action)
        # Have the agent observe the timestep and let the actor update itself.
        actor.observe(action, next_timestep=timestep)

        cur_return += pow(discount, cur_step) * timestep.reward
        cur_step += 1

        if timestep.last():
            # Append return of the current episode, and reset the environment.
            returns.append(cur_return)
            lengths.append(cur_step)
            timestep = environment.reset()
            actor.observe_first(timestep)
            cur_return = 0.
            cur_step = 0

            # Guard against modulo-by-zero when num_episodes < 10.
            progress_every = max(1, FLAGS.num_episodes // 10)
            if len(returns) % progress_every == 0:
                print(
                    f'Run time {time.time() - t_start:0.0f} secs, '
                    f'evaluated episode {len(returns)} / {FLAGS.num_episodes}')

    # The reported results include the problem config...
    results = {
        '_'.join(keys): value
        for keys, value in tree.flatten_with_path(problem_config)
    }

    # ... and the computed return statistics.
    results.update({
        'metric_value': np.mean(returns),
        'metric_std_dev': np.std(returns, ddof=0),
        'metric_std_err': np.std(returns, ddof=0) / np.sqrt(len(returns)),
        'length_mean': np.mean(lengths),
        'length_std': np.std(lengths, ddof=0),
        'num_episodes': len(returns),
    })
    logger.write(results)
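For reference, the return accumulated incrementally in the episode loop above (cur_return += pow(discount, cur_step) * timestep.reward) is the standard discounted return. A stand-alone equivalent, using hypothetical names:

def discounted_return(rewards, discount):
  """Sum of discount**t * r_t over an episode's rewards."""
  return sum(discount**t * r for t, r in enumerate(rewards))

# e.g. discounted_return([1.0, 1.0, 1.0], 0.9) ≈ 2.71  (1 + 0.9 + 0.81)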
Example #4
def main(_):
    problem_config = FLAGS.problem_config

    # Load the offline dataset and environment.
    dataset, dev_dataset, environment = utils.load_data_and_env(
        task_name=problem_config['task_name'],
        noise_level=problem_config['noise_level'],
        near_policy_dataset=problem_config['near_policy_dataset'],
        dataset_path=FLAGS.dataset_path,
        batch_size=FLAGS.batch_size,
        max_dev_size=FLAGS.max_dev_size)
    environment_spec = specs.make_environment_spec(environment)

    # Create the networks to optimize.
    value_func, mixture_density = deepiv.make_ope_networks(
        problem_config['task_name'],
        environment_spec=environment_spec,
        density_layer_sizes=FLAGS.density_layer_sizes,
        value_layer_sizes=FLAGS.value_layer_sizes,
        num_cat=FLAGS.num_cat)

    # Load pretrained target policy network.
    target_policy_net = utils.load_policy_net(
        task_name=problem_config['task_name'],
        noise_level=problem_config['noise_level'],
        near_policy_dataset=problem_config['near_policy_dataset'],
        dataset_path=FLAGS.dataset_path,
        environment_spec=environment_spec)

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # The learner updates the parameters (and initializes them).
    learner = deepiv.DeepIVLearner(
        value_func=value_func,
        mixture_density=mixture_density,
        policy_net=target_policy_net,
        discount=problem_config['discount'],
        value_learning_rate=FLAGS.value_learning_rate,
        density_learning_rate=FLAGS.density_learning_rate,
        n_sampling=FLAGS.n_sampling,
        density_iter=FLAGS.density_iter,
        dataset=dataset,
        counter=learner_counter)

    eval_counter = counting.Counter(counter, 'eval')
    eval_logger = loggers.TerminalLogger('eval')

    while True:
        learner.step()
        steps = learner.state['num_steps'].numpy()

        if steps % FLAGS.evaluate_every == 0:
            eval_results = {}
            if dev_dataset is not None:
                eval_results.update(learner.dev_loss(dev_dataset))
            eval_results.update(
                utils.ope_evaluation(
                    value_func=value_func,
                    policy_net=target_policy_net,
                    environment=environment,
                    num_init_samples=FLAGS.evaluate_init_samples,
                    discount=problem_config['discount'],
                    counter=eval_counter))
            eval_logger.write(eval_results)

        # Stop once the density-fitting and value-fitting step budgets have
        # both been used.
        if steps >= FLAGS.density_iter + FLAGS.value_iter:
            break
Example #5
def main(_):
    problem_config = FLAGS.problem_config

    # Load the offline dataset and environment.
    dataset, dev_dataset, environment = utils.load_data_and_env(
        task_name=problem_config['task_name'],
        noise_level=problem_config['noise_level'],
        near_policy_dataset=problem_config['near_policy_dataset'],
        dataset_path=FLAGS.dataset_path,
        batch_size=FLAGS.batch_size,
        max_dev_size=FLAGS.max_dev_size)
    environment_spec = specs.make_environment_spec(environment)

    # Create the networks to optimize.
    value_func = fqe.make_ope_networks(problem_config['task_name'],
                                       environment_spec,
                                       distributional=FLAGS.distributional,
                                       layer_sizes=FLAGS.layer_sizes,
                                       vmin=FLAGS.vmin,
                                       vmax=FLAGS.vmax,
                                       num_atoms=FLAGS.num_atoms)
    target_value_func = copy.deepcopy(value_func)

    # Load pretrained target policy network.
    target_policy_net = utils.load_policy_net(
        task_name=problem_config['task_name'],
        noise_level=problem_config['noise_level'],
        near_policy_dataset=problem_config['near_policy_dataset'],
        dataset_path=FLAGS.dataset_path,
        environment_spec=environment_spec)

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')
    logger = loggers.TerminalLogger('learner')

    # The learner updates the parameters (and initializes them).
    learner = fqe.FQELearner(policy_network=target_policy_net,
                             critic_network=value_func,
                             target_critic_network=target_value_func,
                             discount=problem_config['discount'],
                             target_update_period=FLAGS.target_update_period,
                             vmin=FLAGS.vmin,
                             vmax=FLAGS.vmax,
                             dataset=dataset,
                             distributional=FLAGS.distributional,
                             critic_lr=FLAGS.learning_rate,
                             counter=learner_counter,
                             logger=logger)

    eval_counter = counting.Counter(counter, 'eval')
    eval_logger = loggers.TerminalLogger('eval')

    while True:
        learner.step()
        steps = learner.state['num_steps'].numpy()

        if steps % FLAGS.evaluate_every == 0:
            eval_results = {}
            if dev_dataset is not None:
                eval_results = {
                    'dev_loss': learner.dev_critic_loss(dev_dataset)
                }
            eval_results.update(
                utils.ope_evaluation(
                    value_func=learner.critic_mean,
                    policy_net=target_policy_net,
                    environment=environment,
                    num_init_samples=FLAGS.evaluate_init_samples,
                    discount=problem_config['discount'],
                    counter=eval_counter))
            eval_logger.write(eval_results)

        if steps >= FLAGS.max_steps:
            break
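The vmin/vmax/num_atoms flags in this example presumably parameterize a categorical (C51-style) value distribution when distributional=True. A small illustrative sketch of the implied support, with placeholder values; this is an assumption, not code taken from the fqe module:

import numpy as np

vmin, vmax, num_atoms = 0.0, 100.0, 51       # placeholder values
atoms = np.linspace(vmin, vmax, num_atoms)   # fixed support of the value distribution
probs = np.full(num_atoms, 1.0 / num_atoms)  # e.g. a uniform categorical prediction
value_estimate = np.sum(probs * atoms)       # mean value, cf. the critic_mean used above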