def train_eval_online(
    # Basic args.
    log_dir,
    agent_module,
    env_name='HalfCheetah-v2',
    # Train and eval args.
    total_train_steps=int(1e6),
    summary_freq=100,
    print_freq=1000,
    save_freq=int(1e8),
    eval_freq=5000,
    n_eval_episodes=20,
    # For saving a partially trained policy.
    eval_target=None,  # Target return value to stop training.
    eval_target_n=2,  # Stop after n consecutive evals above eval_target.
    # Agent train args.
    initial_explore_steps=10000,
    replay_buffer_size=int(1e6),
    model_params=(((200, 200),), 2),
    optimizers=(('adam', 0.001),),
    batch_size=256,
    weight_decays=(0.0,),
    update_freq=1,
    update_rate=0.005,
    discount=0.99,
    ):
  """Training a policy with online interaction."""
  # Create tf_env to get specs.
  tf_env = train_eval_utils.env_factory(env_name)
  tf_env_test = train_eval_utils.env_factory(env_name)
  observation_spec = tf_env.observation_spec()
  action_spec = tf_env.action_spec()

  # Initialize dataset.
  with tf.device('/cpu:0'):
    train_data = dataset.Dataset(
        observation_spec,
        action_spec,
        replay_buffer_size,
        circular=True,
        )
  data_ckpt = tf.train.Checkpoint(data=train_data)
  data_ckpt_name = os.path.join(log_dir, 'replay')

  time_st_total = time.time()
  time_st = time.time()
  timed_at_step = 0

  # Collect data from random policy.
  explore_policy = policies.ContinuousRandomPolicy(action_spec)
  steps_collected = 0
  log_freq = 5000
  logging.info('Collecting data ...')
  collector = train_eval_utils.DataCollector(
      tf_env, explore_policy, train_data)
  while steps_collected < initial_explore_steps:
    count = collector.collect_transition()
    steps_collected += count
    if (steps_collected % log_freq == 0
        or steps_collected == initial_explore_steps) and count > 0:
      steps_per_sec = ((steps_collected - timed_at_step)
                       / (time.time() - time_st))
      timed_at_step = steps_collected
      time_st = time.time()
      logging.info('(%d/%d) steps collected at %.4g steps/s.',
                   steps_collected, initial_explore_steps, steps_per_sec)

  # Construct agent.
  agent_flags = utils.Flags(
      action_spec=action_spec,
      model_params=model_params,
      optimizers=optimizers,
      batch_size=batch_size,
      weight_decays=weight_decays,
      update_freq=update_freq,
      update_rate=update_rate,
      discount=discount,
      train_data=train_data)
  agent_args = agent_module.Config(agent_flags).agent_args
  agent = agent_module.Agent(**vars(agent_args))

  # Prepare savers for models and results.
  train_summary_dir = os.path.join(log_dir, 'train')
  eval_summary_dir = os.path.join(log_dir, 'eval')
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_summary_dir)
  eval_summary_writers = collections.OrderedDict()
  for policy_key in agent.test_policies.keys():
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        os.path.join(eval_summary_dir, policy_key))
    eval_summary_writers[policy_key] = eval_summary_writer
  agent_ckpt_name = os.path.join(log_dir, 'agent')
  eval_results = []

  # Train agent.
  logging.info('Start training ...')
  time_st = time.time()
  timed_at_step = 0
  target_partial_policy_saved = False
  collector = train_eval_utils.DataCollector(
      tf_env, agent.online_policy, train_data)
  for _ in range(total_train_steps):
    collector.collect_transition()
    agent.train_step()
    step = agent.global_step
    if step % summary_freq == 0 or step == total_train_steps:
      agent.write_train_summary(train_summary_writer)
    if step % print_freq == 0 or step == total_train_steps:
      agent.print_train_info()
    if step % eval_freq == 0 or step == total_train_steps:
      time_ed = time.time()
      time_cost = time_ed - time_st
      logging.info(
          'Training at %.4g steps/s.', (step - timed_at_step) / time_cost)
      eval_result, eval_infos = train_eval_utils.eval_policies(
          tf_env_test, agent.test_policies, n_eval_episodes)
      eval_results.append([step] + eval_result)
      # Decide whether to save a partially trained policy based on current
      # model performance.
      if (eval_target is not None and len(eval_results) >= eval_target_n
          and not target_partial_policy_saved):
        evals_ = np.array(
            [eval_results[-(i + 1)][1] for i in range(eval_target_n)])
        if np.min(evals_) >= eval_target:
          agent.save(agent_ckpt_name + '_partial_target')
          dataset.save_copy(train_data, data_ckpt_name + '_partial_target')
          logging.info('A partially trained policy was saved at step %d,'
                       ' with episodic return %.4g.', step, evals_[-1])
          target_partial_policy_saved = True
      logging.info('Testing at step %d:', step)
      for policy_key, policy_info in eval_infos.items():
        logging.info(utils.get_summary_str(
            step=None, info=policy_info, prefix=policy_key + ': '))
        utils.write_summary(
            eval_summary_writers[policy_key], step, policy_info)
      time_st = time.time()
      timed_at_step = step
    if step % save_freq == 0:
      agent.save(agent_ckpt_name + '-' + str(step))

  # Final save after training.
  agent.save(agent_ckpt_name + '_final')
  data_ckpt.write(data_ckpt_name + '_final')
  time_cost = time.time() - time_st_total
  logging.info('Training finished, time cost %.4gs.', time_cost)
  return np.array(eval_results)
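
# A minimal usage sketch for train_eval_online. The agent module and paths
# below are assumptions for illustration (any module exposing Config and
# Agent with the interface used above works); this helper is hypothetical
# and not called anywhere in this file.
def _example_online_run():
  from behavior_regularized_offline_rl.brac import sac_agent  # assumed import path
  return train_eval_online(
      log_dir='/tmp/brac_online/HalfCheetah-v2',  # assumed log directory
      agent_module=sac_agent,
      env_name='HalfCheetah-v2',
      total_train_steps=int(1e5),
      eval_freq=5000,
      n_eval_episodes=10)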
def print_train_info(self):
  info = self._train_info
  step = self._global_step.numpy()
  summary_str = utils.get_summary_str(step, info)
  logging.info(summary_str)
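
# For reference, a minimal sketch of a helper compatible with the
# utils.get_summary_str calls in this file, assuming `info` maps scalar names
# to numeric values. The real helper in utils.py may format differently; this
# is illustrative only and unused.
def _get_summary_str_sketch(step=None, info=None, prefix=''):
  parts = [] if step is None else ['Step %d' % step]
  for name, value in (info or {}).items():
    parts.append('%s %.4g' % (name, float(value)))
  return prefix + ', '.join(parts)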
def train_eval_offline(
    # Basic args.
    log_dir,
    data_file,
    agent_module,
    env_name='HalfCheetah-v2',
    n_train=int(1e6),
    shuffle_steps=0,
    seed=0,
    use_seed_for_data=False,
    # Train and eval args.
    total_train_steps=int(1e6),
    summary_freq=100,
    print_freq=1000,
    save_freq=int(2e4),
    eval_freq=5000,
    n_eval_episodes=20,
    # Agent args.
    model_params=(((200, 200),), 2),
    behavior_ckpt_file=None,
    value_penalty=True,
    alpha=1.0,
    optimizers=(('adam', 0.001),),
    batch_size=256,
    weight_decays=(0.0,),
    update_freq=1,
    update_rate=0.005,
    discount=0.99,
    ):
  """Training a policy with a fixed dataset."""
  print('[train_eval_offline.py] env_name=', env_name)
  print('[train_eval_offline.py] data_file=', data_file)
  print('[train_eval_offline.py] agent_module=', agent_module)
  print('[train_eval_offline.py] model_params=', model_params)
  print('[train_eval_offline.py] optimizers=', optimizers)
  print('[train_eval_offline.py] bckpt_file=', behavior_ckpt_file)
  print('[train_eval_offline.py] value_penalty=', value_penalty)

  # Create tf_env to get specs.
  tf_env = env_factory(env_name)
  observation_spec = tf_env.observation_spec()
  action_spec = tf_env.action_spec()

  # Prepare data.
  full_data = get_offline_data(tf_env)

  # Split data.
  n_train = min(n_train, full_data.size)
  logging.info('n_train %s.', n_train)
  if use_seed_for_data:
    rand = np.random.RandomState(seed)
  else:
    rand = np.random.RandomState(0)
  shuffled_indices = utils.shuffle_indices_with_steps(
      n=full_data.size, steps=shuffle_steps, rand=rand)
  train_indices = shuffled_indices[:n_train]
  train_data = full_data.create_view(train_indices)

  # Create agent.
  agent_flags = utils.Flags(
      observation_spec=observation_spec,
      action_spec=action_spec,
      model_params=model_params,
      optimizers=optimizers,
      batch_size=batch_size,
      weight_decays=weight_decays,
      update_freq=update_freq,
      update_rate=update_rate,
      discount=discount,
      train_data=train_data)
  agent_args = agent_module.Config(agent_flags).agent_args
  # Copy agent args into a plain dict so BRAC-primal-specific settings can be
  # injected before constructing the agent.
  my_agent_arg_dict = dict(vars(agent_args))
  if 'brac_primal' in agent_module.__name__:
    my_agent_arg_dict['behavior_ckpt_file'] = behavior_ckpt_file
    my_agent_arg_dict['value_penalty'] = value_penalty
    my_agent_arg_dict['alpha'] = alpha
  print('agent:', agent_module.__name__)
  print('agent_args:', my_agent_arg_dict)
  agent = agent_module.Agent(**my_agent_arg_dict)
  agent_ckpt_name = os.path.join(log_dir, 'agent')

  # Restore agent from checkpoint if there exists one.
  if tf.io.gfile.exists('{}.index'.format(agent_ckpt_name)):
    logging.info('Checkpoint found at %s.', agent_ckpt_name)
    agent.restore(agent_ckpt_name)

  # Train agent.
  train_summary_dir = os.path.join(log_dir, 'train')
  eval_summary_dir = os.path.join(log_dir, 'eval')
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_summary_dir)
  eval_summary_writers = collections.OrderedDict()
  for policy_key in agent.test_policies.keys():
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        os.path.join(eval_summary_dir, policy_key))
    eval_summary_writers[policy_key] = eval_summary_writer
  eval_results = []

  time_st_total = time.time()
  time_st = time.time()
  step = agent.global_step
  timed_at_step = step
  while step < total_train_steps:
    agent.train_step()
    step = agent.global_step
    if step % summary_freq == 0 or step == total_train_steps:
      agent.write_train_summary(train_summary_writer)
    if step % print_freq == 0 or step == total_train_steps:
      agent.print_train_info()
    if step % eval_freq == 0 or step == total_train_steps:
      time_ed = time.time()
      time_cost = time_ed - time_st
      logging.info('Training at %.4g steps/s.',
                   (step - timed_at_step) / time_cost)
      eval_result, eval_infos = train_eval_utils.eval_policies(
          tf_env, agent.test_policies, n_eval_episodes)
      eval_results.append([step] + eval_result)
      with open(os.path.join(log_dir, 'results.txt'), 'a') as logfile:
        logfile.write(str(eval_result) + '\n')
      logging.info('Testing at step %d:', step)
      for policy_key, policy_info in eval_infos.items():
        logging.info(
            utils.get_summary_str(
                step=None, info=policy_info, prefix=policy_key + ': '))
        utils.write_summary(
            eval_summary_writers[policy_key], step, policy_info)
      time_st = time.time()
      timed_at_step = step
    if step % save_freq == 0:
      agent.save(agent_ckpt_name)
      logging.info('Agent saved at %s.', agent_ckpt_name)

  agent.save(agent_ckpt_name)
  time_cost = time.time() - time_st_total
  logging.info('Training finished, time cost %.4gs.', time_cost)
  return np.array(eval_results)
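
# A usage sketch for this offline variant with a BRAC-primal agent: any agent
# module whose __name__ contains 'brac_primal' additionally receives
# behavior_ckpt_file, value_penalty and alpha. The import path and file
# locations are assumptions; this helper is hypothetical and never called.
def _example_brac_primal_offline_run():
  from behavior_regularized_offline_rl.brac import brac_primal_agent  # assumed path
  return train_eval_offline(
      log_dir='/tmp/brac_offline/HalfCheetah-v2',      # assumed log directory
      data_file='/tmp/brac_data/HalfCheetah-v2/data',  # assumed data file
      agent_module=brac_primal_agent,
      behavior_ckpt_file='/tmp/brac_bc/HalfCheetah-v2/agent_behavior',  # assumed BC ckpt
      value_penalty=True,
      alpha=1.0)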
def train_eval_offline(
    # Basic args.
    log_dir,
    data_file,
    agent_module,
    env_name='HalfCheetah-v2',
    n_train=int(1e6),
    shuffle_steps=0,
    seed=0,
    use_seed_for_data=False,
    # Train and eval args.
    total_train_steps=int(1e6),
    summary_freq=100,
    print_freq=1000,
    save_freq=int(2e4),
    eval_freq=5000,
    n_eval_episodes=20,
    # Agent args.
    model_params=(((200, 200),), 2),
    optimizers=(('adam', 0.001),),
    batch_size=256,
    weight_decays=(0.0,),
    update_freq=1,
    update_rate=0.005,
    discount=0.99,
    ):
  """Training a policy with a fixed dataset."""
  # Create tf_env to get specs.
  tf_env = train_eval_utils.env_factory(env_name)
  observation_spec = tf_env.observation_spec()
  action_spec = tf_env.action_spec()

  # Prepare data.
  logging.info('Loading data from %s ...', data_file)
  data_size = utils.load_variable_from_ckpt(data_file, 'data._capacity')
  with tf.device('/cpu:0'):
    full_data = dataset.Dataset(observation_spec, action_spec, data_size)
  data_ckpt = tf.train.Checkpoint(data=full_data)
  data_ckpt.restore(data_file)

  # Split data.
  n_train = min(n_train, full_data.size)
  logging.info('n_train %s.', n_train)
  if use_seed_for_data:
    rand = np.random.RandomState(seed)
  else:
    rand = np.random.RandomState(0)
  shuffled_indices = utils.shuffle_indices_with_steps(
      n=full_data.size, steps=shuffle_steps, rand=rand)
  train_indices = shuffled_indices[:n_train]
  train_data = full_data.create_view(train_indices)

  # Create agent.
  agent_flags = utils.Flags(
      observation_spec=observation_spec,
      action_spec=action_spec,
      model_params=model_params,
      optimizers=optimizers,
      batch_size=batch_size,
      weight_decays=weight_decays,
      update_freq=update_freq,
      update_rate=update_rate,
      discount=discount,
      train_data=train_data)
  agent_args = agent_module.Config(agent_flags).agent_args
  agent = agent_module.Agent(**vars(agent_args))
  agent_ckpt_name = os.path.join(log_dir, 'agent')

  # Restore agent from checkpoint if there exists one.
  if tf.io.gfile.exists('{}.index'.format(agent_ckpt_name)):
    logging.info('Checkpoint found at %s.', agent_ckpt_name)
    agent.restore(agent_ckpt_name)

  # Train agent.
  train_summary_dir = os.path.join(log_dir, 'train')
  eval_summary_dir = os.path.join(log_dir, 'eval')
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_summary_dir)
  eval_summary_writers = collections.OrderedDict()
  for policy_key in agent.test_policies.keys():
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        os.path.join(eval_summary_dir, policy_key))
    eval_summary_writers[policy_key] = eval_summary_writer
  eval_results = []

  time_st_total = time.time()
  time_st = time.time()
  step = agent.global_step
  timed_at_step = step
  while step < total_train_steps:
    agent.train_step()
    step = agent.global_step
    if step % summary_freq == 0 or step == total_train_steps:
      agent.write_train_summary(train_summary_writer)
    if step % print_freq == 0 or step == total_train_steps:
      agent.print_train_info()
    if step % eval_freq == 0 or step == total_train_steps:
      time_ed = time.time()
      time_cost = time_ed - time_st
      logging.info(
          'Training at %.4g steps/s.', (step - timed_at_step) / time_cost)
      eval_result, eval_infos = train_eval_utils.eval_policies(
          tf_env, agent.test_policies, n_eval_episodes)
      eval_results.append([step] + eval_result)
      logging.info('Testing at step %d:', step)
      for policy_key, policy_info in eval_infos.items():
        logging.info(utils.get_summary_str(
            step=None, info=policy_info, prefix=policy_key + ': '))
        utils.write_summary(
            eval_summary_writers[policy_key], step, policy_info)
      time_st = time.time()
      timed_at_step = step
    if step % save_freq == 0:
      agent.save(agent_ckpt_name)
      logging.info('Agent saved at %s.', agent_ckpt_name)

  agent.save(agent_ckpt_name)
  time_cost = time.time() - time_st_total
  logging.info('Training finished, time cost %.4gs.', time_cost)
  return np.array(eval_results)
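
# All three trainers above return eval_results as a NumPy array whose rows are
# [step, return_of_first_test_policy, ...], one row per evaluation; the online
# trainer compares column 1 against eval_target. A minimal sketch of consuming
# that array (variable names here are illustrative):
#
#   results = train_eval_offline(log_dir, data_file, agent_module)
#   steps, returns = results[:, 0], results[:, 1]
#   best_step = int(steps[np.argmax(returns)])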