def collect_data(
    log_dir,
    data_config,
    n_samples=int(1e6),
    env_name='HalfCheetah-v2',
    log_freq=int(1e4),
    n_eval_episodes=20,
):
  """Creates a dataset of transitions based on the desired config."""
  tf_env = train_eval_utils.env_factory(env_name)
  observation_spec = tf_env.observation_spec()
  action_spec = tf_env.action_spec()

  # Initialize dataset.
  sample_sizes = [cfg[-1] for cfg in data_config]
  sample_sizes = get_sample_counts(n_samples, sample_sizes)
  with tf.device('/cpu:0'):
    data = dataset.Dataset(
        observation_spec,
        action_spec,
        n_samples,
        circular=False)
  data_ckpt = tf.train.Checkpoint(data=data)
  data_ckpt_name = os.path.join(log_dir, 'data')

  # Collect data for each policy in data_config.
  time_st = time.time()
  test_results = collections.OrderedDict()
  for (policy_name, policy_cfg, _), n_transitions in zip(
      data_config, sample_sizes):
    policy_cfg = policy_loader.parse_policy_cfg(policy_cfg)
    policy = policy_loader.load_policy(policy_cfg, action_spec)
    logging.info('Testing policy %s...', policy_name)
    eval_mean, eval_std = train_eval_utils.eval_policy_episodes(
        tf_env, policy, n_eval_episodes)
    test_results[policy_name] = [eval_mean, eval_std]
    logging.info('Return mean %.4g, std %.4g.', eval_mean, eval_std)
    logging.info('Collecting data from policy %s...', policy_name)
    collect_n_transitions(tf_env, policy, data, n_transitions, log_freq)

  # Save the final dataset.
  assert data.size == data.capacity
  data_ckpt.write(data_ckpt_name)
  time_cost = time.time() - time_st
  logging.info(
      'Finished: %d transitions collected, '
      'saved at %s, '
      'time cost %.4gs.',
      n_samples, data_ckpt_name, time_cost)
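
# Usage sketch (assumptions, not part of the original API): each data_config
# entry is a (policy_name, policy_cfg, proportion) tuple; collect_data reads
# the trailing proportion via cfg[-1] to split n_samples across policies with
# get_sample_counts. EXAMPLE_POLICY_CFG_A/B below are hypothetical
# placeholders; the real schema is whatever policy_loader.parse_policy_cfg
# accepts.
#
#   example_data_config = [
#       ('behavior_a', EXAMPLE_POLICY_CFG_A, 0.5),  # hypothetical cfg
#       ('behavior_b', EXAMPLE_POLICY_CFG_B, 0.5),  # hypothetical cfg
#   ]
#   collect_data(
#       log_dir='/tmp/offlinerl/data',
#       data_config=example_data_config,
#       n_samples=int(1e5),
#       env_name='HalfCheetah-v2')
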
def get_offline_data(tf_env):
  """Loads a D4RL-style offline dataset from the wrapped gym environment."""
  gym_env = tf_env.pyenv.envs[0]
  # offline_dataset = gym_env.unwrapped.get_dataset()
  offline_dataset = gym_env.get_dataset()
  dataset_size = len(offline_dataset['observations'])
  tf_dataset = dataset.Dataset(
      tf_env.observation_spec(),
      tf_env.action_spec(),
      size=dataset_size)
  observation_dtype = tf_env.observation_spec().dtype
  action_dtype = tf_env.action_spec().dtype

  offline_dataset['terminals'] = np.squeeze(offline_dataset['terminals'])
  offline_dataset['rewards'] = np.squeeze(offline_dataset['rewards'])
  # Keep only steps that are non-terminal and have a valid successor step.
  (nonterminal_steps,) = np.where(
      np.logical_and(
          np.logical_not(offline_dataset['terminals']),
          np.arange(dataset_size) < dataset_size - 1))
  logging.info('Found %d non-terminal steps out of a total of %d steps.',
               len(nonterminal_steps), dataset_size)

  s1 = tf.convert_to_tensor(
      offline_dataset['observations'][nonterminal_steps],
      dtype=observation_dtype)
  s2 = tf.convert_to_tensor(
      offline_dataset['observations'][nonterminal_steps + 1],
      dtype=observation_dtype)
  a1 = tf.convert_to_tensor(
      offline_dataset['actions'][nonterminal_steps], dtype=action_dtype)
  a2 = tf.convert_to_tensor(
      offline_dataset['actions'][nonterminal_steps + 1], dtype=action_dtype)
  discount = tf.convert_to_tensor(
      1.0 - offline_dataset['terminals'][nonterminal_steps + 1],
      dtype=tf.float32)
  reward = tf.convert_to_tensor(
      offline_dataset['rewards'][nonterminal_steps], dtype=tf.float32)

  transitions = dataset.Transition(s1, s2, a1, a2, discount, reward)
  tf_dataset.add_transitions(transitions)
  return tf_dataset
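
# Usage sketch: get_offline_data assumes the wrapped gym environment exposes
# a D4RL-style get_dataset() returning 'observations', 'actions', 'rewards',
# and 'terminals' arrays. A minimal sketch, assuming env_factory can build a
# tf_env around a D4RL task such as 'halfcheetah-medium-v0':
#
#   tf_env = train_eval_utils.env_factory('halfcheetah-medium-v0')
#   offline_data = get_offline_data(tf_env)
#   logging.info('Loaded %d transitions.', offline_data.size)
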
def train_eval_online(
    # Basic args.
    log_dir,
    agent_module,
    env_name='HalfCheetah-v2',
    # Train and eval args.
    total_train_steps=int(1e6),
    summary_freq=100,
    print_freq=1000,
    save_freq=int(1e8),
    eval_freq=5000,
    n_eval_episodes=20,
    # For saving a partially trained policy.
    eval_target=None,  # Target return value to stop training.
    eval_target_n=2,  # Stop after n consecutive evals above eval_target.
    # Agent train args.
    initial_explore_steps=10000,
    replay_buffer_size=int(1e6),
    model_params=(((200, 200),), 2),
    optimizers=(('adam', 0.001),),
    batch_size=256,
    weight_decays=(0.0,),
    update_freq=1,
    update_rate=0.005,
    discount=0.99,
):
  """Trains a policy with online interaction."""
  # Create tf_env to get specs.
  tf_env = train_eval_utils.env_factory(env_name)
  tf_env_test = train_eval_utils.env_factory(env_name)
  observation_spec = tf_env.observation_spec()
  action_spec = tf_env.action_spec()

  # Initialize the replay buffer.
  with tf.device('/cpu:0'):
    train_data = dataset.Dataset(
        observation_spec,
        action_spec,
        replay_buffer_size,
        circular=True)
  data_ckpt = tf.train.Checkpoint(data=train_data)
  data_ckpt_name = os.path.join(log_dir, 'replay')

  time_st_total = time.time()
  time_st = time.time()
  timed_at_step = 0

  # Collect data from a random policy.
  explore_policy = policies.ContinuousRandomPolicy(action_spec)
  steps_collected = 0
  log_freq = 5000
  logging.info('Collecting data ...')
  collector = train_eval_utils.DataCollector(
      tf_env, explore_policy, train_data)
  while steps_collected < initial_explore_steps:
    count = collector.collect_transition()
    steps_collected += count
    if (steps_collected % log_freq == 0 or
        steps_collected == initial_explore_steps) and count > 0:
      steps_per_sec = ((steps_collected - timed_at_step)
                       / (time.time() - time_st))
      timed_at_step = steps_collected
      time_st = time.time()
      logging.info('(%d/%d) steps collected at %.4g steps/s.',
                   steps_collected, initial_explore_steps, steps_per_sec)

  # Construct the agent.
  agent_flags = utils.Flags(
      action_spec=action_spec,
      model_params=model_params,
      optimizers=optimizers,
      batch_size=batch_size,
      weight_decays=weight_decays,
      update_freq=update_freq,
      update_rate=update_rate,
      discount=discount,
      train_data=train_data)
  agent_args = agent_module.Config(agent_flags).agent_args
  agent = agent_module.Agent(**vars(agent_args))

  # Prepare savers for models and results.
  train_summary_dir = os.path.join(log_dir, 'train')
  eval_summary_dir = os.path.join(log_dir, 'eval')
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_summary_dir)
  eval_summary_writers = collections.OrderedDict()
  for policy_key in agent.test_policies.keys():
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        os.path.join(eval_summary_dir, policy_key))
    eval_summary_writers[policy_key] = eval_summary_writer
  agent_ckpt_name = os.path.join(log_dir, 'agent')
  eval_results = []

  # Train the agent.
  logging.info('Start training ...')
  time_st = time.time()
  timed_at_step = 0
  target_partial_policy_saved = False
  collector = train_eval_utils.DataCollector(
      tf_env, agent.online_policy, train_data)
  for _ in range(total_train_steps):
    collector.collect_transition()
    agent.train_step()
    step = agent.global_step
    if step % summary_freq == 0 or step == total_train_steps:
      agent.write_train_summary(train_summary_writer)
    if step % print_freq == 0 or step == total_train_steps:
      agent.print_train_info()
    if step % eval_freq == 0 or step == total_train_steps:
      time_ed = time.time()
      time_cost = time_ed - time_st
      logging.info(
          'Training at %.4g steps/s.', (step - timed_at_step) / time_cost)
      eval_result, eval_infos = train_eval_utils.eval_policies(
          tf_env_test, agent.test_policies, n_eval_episodes)
      eval_results.append([step] + eval_result)
      # Decide whether to save a partially trained policy based on current
      # model performance.
      if (eval_target is not None and len(eval_results) >= eval_target_n
          and not target_partial_policy_saved):
        evals_ = np.array(
            [eval_results[-(i + 1)][1] for i in range(eval_target_n)])
        if np.min(evals_) >= eval_target:
          agent.save(agent_ckpt_name + '_partial_target')
          dataset.save_copy(train_data, data_ckpt_name + '_partial_target')
          logging.info('A partially trained policy was saved at step %d,'
                       ' with episodic return %.4g.', step, evals_[-1])
          target_partial_policy_saved = True
      logging.info('Testing at step %d:', step)
      for policy_key, policy_info in eval_infos.items():
        logging.info(utils.get_summary_str(
            step=None, info=policy_info, prefix=policy_key + ': '))
        utils.write_summary(eval_summary_writers[policy_key], step,
                            policy_info)
      time_st = time.time()
      timed_at_step = step
    if step % save_freq == 0:
      agent.save(agent_ckpt_name + '-' + str(step))

  # Final save after training.
  agent.save(agent_ckpt_name + '_final')
  data_ckpt.write(data_ckpt_name + '_final')
  time_cost = time.time() - time_st_total
  logging.info('Training finished, time cost %.4gs.', time_cost)
  return np.array(eval_results)
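
# Usage sketch: launching online training, assuming `sac_agent` is an agent
# module in this codebase exposing Config and Agent (the module name is
# illustrative, not confirmed here). The returned array stacks one row per
# evaluation: the training step followed by the eval_policies results.
#
#   results = train_eval_online(
#       log_dir='/tmp/offlinerl/online',
#       agent_module=sac_agent,
#       total_train_steps=int(5e5),
#       eval_target=4000.0,  # Save a partial policy once eval_target_n
#       eval_target_n=2)     # consecutive evals reach this return.
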
def train_eval_offline(
    # Basic args.
    log_dir,
    data_file,
    agent_module,
    env_name='HalfCheetah-v2',
    n_train=int(1e6),
    shuffle_steps=0,
    seed=0,
    use_seed_for_data=False,
    # Train and eval args.
    total_train_steps=int(1e6),
    summary_freq=100,
    print_freq=1000,
    save_freq=int(2e4),
    eval_freq=5000,
    n_eval_episodes=20,
    # Agent args.
    model_params=(((200, 200),), 2),
    optimizers=(('adam', 0.001),),
    batch_size=256,
    weight_decays=(0.0,),
    update_freq=1,
    update_rate=0.005,
    discount=0.99,
):
  """Trains a policy with a fixed dataset."""
  # Create tf_env to get specs.
  tf_env = train_eval_utils.env_factory(env_name)
  observation_spec = tf_env.observation_spec()
  action_spec = tf_env.action_spec()

  # Prepare data.
  logging.info('Loading data from %s ...', data_file)
  data_size = utils.load_variable_from_ckpt(data_file, 'data._capacity')
  with tf.device('/cpu:0'):
    full_data = dataset.Dataset(observation_spec, action_spec, data_size)
  data_ckpt = tf.train.Checkpoint(data=full_data)
  data_ckpt.restore(data_file)

  # Split data.
  n_train = min(n_train, full_data.size)
  logging.info('n_train %s.', n_train)
  if use_seed_for_data:
    rand = np.random.RandomState(seed)
  else:
    rand = np.random.RandomState(0)
  shuffled_indices = utils.shuffle_indices_with_steps(
      n=full_data.size, steps=shuffle_steps, rand=rand)
  train_indices = shuffled_indices[:n_train]
  train_data = full_data.create_view(train_indices)

  # Create the agent.
  agent_flags = utils.Flags(
      observation_spec=observation_spec,
      action_spec=action_spec,
      model_params=model_params,
      optimizers=optimizers,
      batch_size=batch_size,
      weight_decays=weight_decays,
      update_freq=update_freq,
      update_rate=update_rate,
      discount=discount,
      train_data=train_data)
  agent_args = agent_module.Config(agent_flags).agent_args
  agent = agent_module.Agent(**vars(agent_args))
  agent_ckpt_name = os.path.join(log_dir, 'agent')

  # Restore the agent from a checkpoint if one exists.
  if tf.io.gfile.exists('{}.index'.format(agent_ckpt_name)):
    logging.info('Checkpoint found at %s.', agent_ckpt_name)
    agent.restore(agent_ckpt_name)

  # Train the agent.
  train_summary_dir = os.path.join(log_dir, 'train')
  eval_summary_dir = os.path.join(log_dir, 'eval')
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_summary_dir)
  eval_summary_writers = collections.OrderedDict()
  for policy_key in agent.test_policies.keys():
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        os.path.join(eval_summary_dir, policy_key))
    eval_summary_writers[policy_key] = eval_summary_writer
  eval_results = []

  time_st_total = time.time()
  time_st = time.time()
  step = agent.global_step
  timed_at_step = step
  while step < total_train_steps:
    agent.train_step()
    step = agent.global_step
    if step % summary_freq == 0 or step == total_train_steps:
      agent.write_train_summary(train_summary_writer)
    if step % print_freq == 0 or step == total_train_steps:
      agent.print_train_info()
    if step % eval_freq == 0 or step == total_train_steps:
      time_ed = time.time()
      time_cost = time_ed - time_st
      logging.info(
          'Training at %.4g steps/s.', (step - timed_at_step) / time_cost)
      eval_result, eval_infos = train_eval_utils.eval_policies(
          tf_env, agent.test_policies, n_eval_episodes)
      eval_results.append([step] + eval_result)
      logging.info('Testing at step %d:', step)
      for policy_key, policy_info in eval_infos.items():
        logging.info(utils.get_summary_str(
            step=None, info=policy_info, prefix=policy_key + ': '))
        utils.write_summary(eval_summary_writers[policy_key], step,
                            policy_info)
      time_st = time.time()
      timed_at_step = step
    if step % save_freq == 0:
      agent.save(agent_ckpt_name)
      logging.info('Agent saved at %s.', agent_ckpt_name)

  agent.save(agent_ckpt_name)
  time_cost = time.time() - time_st_total
  logging.info('Training finished, time cost %.4gs.', time_cost)
  return np.array(eval_results)
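
# Usage sketch: offline training from a dataset checkpoint written by
# collect_data (or from the replay buffer saved by train_eval_online),
# assuming `bc_agent` is an agent module in this codebase (illustrative
# name):
#
#   results = train_eval_offline(
#       log_dir='/tmp/offlinerl/offline',
#       data_file='/tmp/offlinerl/data/data',
#       agent_module=bc_agent,
#       n_train=int(5e5),
#       use_seed_for_data=True,  # Tie the train split to `seed`.
#       seed=1)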