def main(_):
  problem_config = FLAGS.problem_config

  # Load the offline dataset and environment.
  dataset, dev_dataset, environment = utils.load_data_and_env(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      batch_size=FLAGS.batch_size,
      max_dev_size=FLAGS.max_dev_size,
      shuffle=False,
      repeat=False)
  environment_spec = specs.make_environment_spec(environment)

  # task_gamma_map = {
  #     'bsuite_catch': 0.25,
  #     'bsuite_mountain_car': 0.5,
  #     'bsuite_cartpole': 0.44,
  # }
  # gamma = FLAGS.gamma or task_gamma_map[problem_config['task_name']]
  gamma = utils.get_median(
      problem_config['task_name'], environment_spec, dataset)

  # Create the networks to optimize.
  value_func, instrumental_feature = kiv_batch.make_ope_networks(
      problem_config['task_name'],
      environment_spec,
      n_component=FLAGS.n_component,
      gamma=gamma)

  # Load pretrained target policy network.
  target_policy_net = utils.load_policy_net(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      environment_spec=environment_spec)

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')
  logger = loggers.TerminalLogger('learner')

  # The learner updates the parameters (and initializes them).
  num_batches = 0
  for _ in dataset:
    num_batches += 1
  stage1_batch = num_batches // 2
  stage2_batch = num_batches - stage1_batch
  learner = kiv_batch.KIVLearner(
      value_func=value_func,
      instrumental_feature=instrumental_feature,
      policy_net=target_policy_net,
      discount=problem_config['discount'],
      stage1_reg=FLAGS.stage1_reg,
      stage2_reg=FLAGS.stage2_reg,
      stage1_batch=stage1_batch,
      stage2_batch=stage2_batch,
      dataset=dataset,
      valid_dataset=dev_dataset,
      counter=learner_counter,
      logger=logger,
      checkpoint=False)

  eval_counter = counting.Counter(counter, 'eval')
  eval_logger = loggers.TerminalLogger('eval')

  while True:
    results = {
        'gamma': gamma,
        'stage1_batch': stage1_batch,
        'stage2_batch': stage2_batch,
    }
    # Include learner results in eval results for ease of analysis.
    results.update(learner.step())
    results.update(
        utils.ope_evaluation(
            value_func=value_func,
            policy_net=target_policy_net,
            environment=environment,
            num_init_samples=FLAGS.evaluate_init_samples,
            discount=problem_config['discount'],
            counter=eval_counter))
    eval_logger.write(results)
    if learner.state['num_steps'] >= FLAGS.max_steps:
      break
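# Note on `gamma` in the KIV runner above: judging from its name,
# utils.get_median presumably computes a median-heuristic kernel bandwidth from
# the dataset (replacing the hard-coded per-task values left commented out).
# Below is a generic, self-contained sketch of that heuristic for illustration
# only; the repo's utils.get_median may differ (e.g. it may work directly on
# batched dm_env observations rather than a plain array), and the helper name
# is hypothetical.
import numpy as np


def median_heuristic_bandwidth(x: np.ndarray) -> float:
  """Median of pairwise Euclidean distances over rows of `x` (hypothetical helper)."""
  dists = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
  upper = dists[np.triu_indices(x.shape[0], k=1)]
  return float(np.median(upper))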
def main(_):
  problem_config = FLAGS.problem_config

  # Load the offline dataset and environment.
  dataset, dev_dataset, environment = utils.load_data_and_env(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      batch_size=FLAGS.batch_size,
      max_dev_size=FLAGS.max_dev_size)
  environment_spec = specs.make_environment_spec(environment)

  # Create the networks to optimize.
  value_func, instrumental_feature = dfiv.make_ope_networks(
      problem_config['task_name'],
      environment_spec,
      value_layer_sizes=FLAGS.value_layer_sizes,
      instrumental_layer_sizes=FLAGS.instrumental_layer_sizes)

  # Load pretrained target policy network.
  target_policy_net = utils.load_policy_net(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      environment_spec=environment_spec)

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')
  logger = loggers.TerminalLogger('learner')

  # The learner updates the parameters (and initializes them).
  learner_cls = dfiv.DFIVLearner
  if FLAGS.learner2:
    learner_cls = dfiv.DFIV2Learner
  learner = learner_cls(
      value_func=value_func,
      instrumental_feature=instrumental_feature,
      policy_net=target_policy_net,
      discount=problem_config['discount'],
      value_learning_rate=FLAGS.value_learning_rate,
      instrumental_learning_rate=FLAGS.instrumental_learning_rate,
      stage1_reg=FLAGS.stage1_reg,
      stage2_reg=FLAGS.stage2_reg,
      value_reg=FLAGS.value_reg,
      instrumental_reg=FLAGS.instrumental_reg,
      instrumental_iter=FLAGS.instrumental_iter,
      value_iter=FLAGS.value_iter,
      dataset=dataset,
      d_tm1_weight=FLAGS.d_tm1_weight,
      counter=learner_counter,
      logger=logger)

  eval_counter = counting.Counter(counter, 'eval')
  eval_logger = loggers.TerminalLogger('eval')

  while True:
    learner.step()
    steps = learner.state['num_steps'].numpy()
    if steps % FLAGS.evaluate_every == 0:
      eval_results = {}
      if dev_dataset is not None:
        eval_results = {'dev_mse': learner.cal_validation_err(dev_dataset)}
      eval_results.update(utils.ope_evaluation(
          value_func=value_func,
          policy_net=target_policy_net,
          environment=environment,
          num_init_samples=FLAGS.evaluate_init_samples,
          discount=problem_config['discount'],
          counter=eval_counter))
      eval_logger.write(eval_results)
    if steps >= FLAGS.max_steps:
      break
def main(_):
  problem_config = FLAGS.problem_config

  # Load the offline dataset and environment.
  _, _, environment = utils.load_data_and_env(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      batch_size=1)
  environment_spec = specs.make_environment_spec(environment)

  # Load pretrained target policy network.
  policy_net = utils.load_policy_net(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      environment_spec=environment_spec)

  actor = actors.FeedForwardActor(policy_network=policy_net)

  logger = loggers.TerminalLogger('ground_truth')

  discount = problem_config['discount']

  returns = []
  lengths = []

  t_start = time.time()
  timestep = environment.reset()
  actor.observe_first(timestep)
  cur_return = 0.
  cur_step = 0

  while len(returns) < FLAGS.num_episodes:
    action = actor.select_action(timestep.observation)
    timestep = environment.step(action)

    # Have the agent observe the timestep and let the actor update itself.
    actor.observe(action, next_timestep=timestep)

    cur_return += pow(discount, cur_step) * timestep.reward
    cur_step += 1

    if timestep.last():
      # Append return of the current episode, and reset the environment.
      returns.append(cur_return)
      lengths.append(cur_step)
      timestep = environment.reset()
      actor.observe_first(timestep)
      cur_return = 0.
      cur_step = 0

      if len(returns) % (FLAGS.num_episodes // 10) == 0:
        print(f'Run time {time.time() - t_start:0.0f} secs, '
              f'evaluated episode {len(returns)} / {FLAGS.num_episodes}')

  # Returned data include problem configs.
  results = {
      '_'.join(keys): value
      for keys, value in tree.flatten_with_path(problem_config)
  }
  # And computed results.
  results.update({
      'metric_value': np.mean(returns),
      'metric_std_dev': np.std(returns, ddof=0),
      'metric_std_err': np.std(returns, ddof=0) / np.sqrt(len(returns)),
      'length_mean': np.mean(lengths),
      'length_std': np.std(lengths, ddof=0),
      'num_episodes': len(returns),
  })
  logger.write(results)
def main(_):
  problem_config = FLAGS.problem_config

  # Load the offline dataset and environment.
  dataset, dev_dataset, environment = utils.load_data_and_env(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      batch_size=FLAGS.batch_size,
      max_dev_size=FLAGS.max_dev_size)
  environment_spec = specs.make_environment_spec(environment)

  # Create the networks to optimize.
  value_func, mixture_density = deepiv.make_ope_networks(
      problem_config['task_name'],
      environment_spec=environment_spec,
      density_layer_sizes=FLAGS.density_layer_sizes,
      value_layer_sizes=FLAGS.value_layer_sizes,
      num_cat=FLAGS.num_cat)

  # Load pretrained target policy network.
  target_policy_net = utils.load_policy_net(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      environment_spec=environment_spec)

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')

  # The learner updates the parameters (and initializes them).
  learner = deepiv.DeepIVLearner(
      value_func=value_func,
      mixture_density=mixture_density,
      policy_net=target_policy_net,
      discount=problem_config['discount'],
      value_learning_rate=FLAGS.value_learning_rate,
      density_learning_rate=FLAGS.density_learning_rate,
      n_sampling=FLAGS.n_sampling,
      density_iter=FLAGS.density_iter,
      dataset=dataset,
      counter=learner_counter)

  eval_counter = counting.Counter(counter, 'eval')
  eval_logger = loggers.TerminalLogger('eval')

  while True:
    learner.step()
    steps = learner.state['num_steps'].numpy()
    if steps % FLAGS.evaluate_every == 0:
      eval_results = {}
      if dev_dataset is not None:
        eval_results.update(learner.dev_loss(dev_dataset))
      eval_results.update(
          utils.ope_evaluation(
              value_func=value_func,
              policy_net=target_policy_net,
              environment=environment,
              num_init_samples=FLAGS.evaluate_init_samples,
              discount=problem_config['discount'],
              counter=eval_counter))
      eval_logger.write(eval_results)
    if steps >= FLAGS.density_iter + FLAGS.value_iter:
      break
def main(_):
  problem_config = FLAGS.problem_config

  # Load the offline dataset and environment.
  dataset, dev_dataset, environment = utils.load_data_and_env(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      batch_size=FLAGS.batch_size,
      max_dev_size=FLAGS.max_dev_size)
  environment_spec = specs.make_environment_spec(environment)

  # Create the networks to optimize.
  value_func = fqe.make_ope_networks(
      problem_config['task_name'],
      environment_spec,
      distributional=FLAGS.distributional,
      layer_sizes=FLAGS.layer_sizes,
      vmin=FLAGS.vmin,
      vmax=FLAGS.vmax,
      num_atoms=FLAGS.num_atoms)
  target_value_func = copy.deepcopy(value_func)

  # Load pretrained target policy network.
  target_policy_net = utils.load_policy_net(
      task_name=problem_config['task_name'],
      noise_level=problem_config['noise_level'],
      near_policy_dataset=problem_config['near_policy_dataset'],
      dataset_path=FLAGS.dataset_path,
      environment_spec=environment_spec)

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')
  logger = loggers.TerminalLogger('learner')

  # The learner updates the parameters (and initializes them).
  learner = fqe.FQELearner(
      policy_network=target_policy_net,
      critic_network=value_func,
      target_critic_network=target_value_func,
      discount=problem_config['discount'],
      target_update_period=FLAGS.target_update_period,
      vmin=FLAGS.vmin,
      vmax=FLAGS.vmax,
      dataset=dataset,
      distributional=FLAGS.distributional,
      critic_lr=FLAGS.learning_rate,
      counter=learner_counter,
      logger=logger)

  eval_counter = counting.Counter(counter, 'eval')
  eval_logger = loggers.TerminalLogger('eval')

  while True:
    learner.step()
    steps = learner.state['num_steps'].numpy()
    if steps % FLAGS.evaluate_every == 0:
      eval_results = {}
      if dev_dataset is not None:
        eval_results = {'dev_loss': learner.dev_critic_loss(dev_dataset)}
      eval_results.update(
          utils.ope_evaluation(
              value_func=learner.critic_mean,
              policy_net=target_policy_net,
              environment=environment,
              num_init_samples=FLAGS.evaluate_init_samples,
              discount=problem_config['discount'],
              counter=eval_counter))
      eval_logger.write(eval_results)
    if steps >= FLAGS.max_steps:
      break
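# Entry-point sketch (assumption): each runner above takes `main(_)` in the
# absl style, so the surrounding script presumably defines its flags and hands
# `main` to app.run. The flag names below mirror the shared FLAGS.* attributes
# used in the code; the defaults and help strings are placeholders, the
# method-specific flags (e.g. value_layer_sizes, vmin/vmax, n_component) are
# omitted, and problem_config is assumed to be supplied elsewhere (e.g. via a
# config flag), so it is not defined here.
from absl import app
from absl import flags

flags.DEFINE_string('dataset_path', None, 'Path to the offline dataset.')
flags.DEFINE_integer('batch_size', 1024, 'Batch size for the offline dataset.')
flags.DEFINE_integer('max_dev_size', 10000, 'Maximum size of the dev dataset.')
flags.DEFINE_integer('max_steps', 100000, 'Maximum number of learner steps.')
flags.DEFINE_integer('evaluate_every', 100, 'Learner steps between evaluations.')
flags.DEFINE_integer('evaluate_init_samples', 100,
                     'Number of initial states sampled during OPE evaluation.')

FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)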