def train_eval(agent_name='behavioral_cloning',
               num_iterations=100,
               batch_size=64,
               train_sequence_length=1):
    """Train an agent for the LLVM inliner from recorded data and save it.

    The output directory comes from FLAGS.root_dir (user-expanded and
    normalized) and the training data location from FLAGS.data_path.

    Args:
      agent_name: agent identifier forwarded to agent_creators.create_agent
        and to the TFRecord iterator factory.
      num_iterations: forwarded to the trainer's train() call.
      batch_size: forwarded to the TFRecord iterator factory.
      train_sequence_length: forwarded to the TFRecord iterator factory.
    """
    output_dir = os.path.normpath(os.path.expanduser(FLAGS.root_dir))

    # Build the agent from the configured signature specs, then wrap it in
    # a trainer and a saver that exports both of its policies.
    specs = config.create_signature_specs(config.CONFIG)
    time_step_spec, action_spec = specs
    agent = agent_creators.create_agent(agent_name, time_step_spec,
                                        action_spec)
    llvm_trainer = trainer.Trainer(root_dir=output_dir, agent=agent)
    saver = policy_saver.PolicySaver(policy_dict={
        'saved_policy': agent.policy,
        'saved_collect_policy': agent.collect_policy,
    })

    make_dataset_iter = data_reader.create_tfrecord_iterator_fn(
        agent_name=agent_name,
        config=config.CONFIG,
        batch_size=batch_size,
        train_sequence_length=train_sequence_length,
    )

    # Run training over the dataset found at FLAGS.data_path.
    llvm_trainer.train(make_dataset_iter(FLAGS.data_path), num_iterations)

    # Export the final policies into the output directory.
    saver.save(output_dir)
def test_create_dqn_agent(self):
    """Requesting the 'dqn' agent must produce a DqnAgent instance."""
    # Gin bindings supply the policy network class and the optimizer.
    gin.bind_parameter('create_agent.policy_network', q_network.QNetwork)
    adam = tf.compat.v1.train.AdamOptimizer()
    gin.bind_parameter('DqnAgent.optimizer', adam)

    created = agent_creators.create_agent(
        agent_name='dqn',
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec,
    )

    self.assertIsInstance(created, dqn_agent.DqnAgent)
def test_create_ppo_agent(self):
    """Requesting the 'ppo' agent must produce a PPOAgent instance."""
    # Gin bindings supply the policy network class and the optimizer.
    network_cls = actor_distribution_network.ActorDistributionNetwork
    gin.bind_parameter('create_agent.policy_network', network_cls)
    adam = tf.compat.v1.train.AdamOptimizer()
    gin.bind_parameter('PPOAgent.optimizer', adam)

    created = agent_creators.create_agent(
        agent_name='ppo',
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec,
    )

    self.assertIsInstance(created, ppo_agent.PPOAgent)
def test_create_behavioral_cloning_agent(self):
    """Requesting 'behavioral_cloning' must produce a BehavioralCloningAgent."""
    # Gin bindings supply the policy network class and the optimizer.
    gin.bind_parameter('create_agent.policy_network', q_network.QNetwork)
    adam = tf.compat.v1.train.AdamOptimizer()
    gin.bind_parameter('BehavioralCloningAgent.optimizer', adam)

    created = agent_creators.create_agent(
        agent_name='behavioral_cloning',
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec,
    )

    expected_cls = behavioral_cloning_agent.BehavioralCloningAgent
    self.assertIsInstance(created, expected_cls)
def train_eval(agent_name='ppo',
               warmstart_policy_dir=None,
               num_policy_iterations=0,
               num_iterations=100,
               batch_size=64,
               train_sequence_length=1,
               deploy_policy_name='saved_policy'):
    """Train for LLVM inliner with locally collected data.

    Each policy iteration saves the current policies, asks the data
    collector to gather data with the saved policy, and trains on the
    result. The final policies are saved to FLAGS.root_dir.

    Reads FLAGS.root_dir, FLAGS.data_path, FLAGS.clang_path,
    FLAGS.llvm_size_path, FLAGS.num_workers and FLAGS.num_modules.

    Args:
      agent_name: agent identifier forwarded to agent_creators.create_agent
        and to the sequence-example iterator factory.
      warmstart_policy_dir: if truthy, a policy is loaded from this
        directory and used to update the agent's policy before training.
      num_policy_iterations: number of save / collect / train rounds.
      num_iterations: forwarded to the trainer's train() call each round.
      batch_size: forwarded to the sequence-example iterator factory.
      train_sequence_length: forwarded to the sequence-example iterator
        factory.
      deploy_policy_name: name joined onto each round's saved-policy
        directory to form the policy path handed to the data collector.
    """
    root_dir = FLAGS.root_dir

    # Initialize trainer and policy saver.
    time_step_spec, action_spec = config.create_signature_specs(config.CONFIG)
    tf_agent = agent_creators.create_agent(agent_name, time_step_spec,
                                           action_spec)
    llvm_trainer = trainer.Trainer(root_dir=root_dir, agent=tf_agent)
    policy_dict = {
        'saved_policy': tf_agent.policy,
        'saved_collect_policy': tf_agent.collect_policy,
    }
    saver = policy_saver.PolicySaver(policy_dict=policy_dict)

    if warmstart_policy_dir:
        warmstart_policy = policy_loader.load(warmstart_policy_dir)
        # NOTE(review): tau=1.0 presumably overwrites the agent policy's
        # variables with the warm-start ones — confirm against TF-Agents.
        tf_agent.policy.update(
            policy=warmstart_policy,
            tau=1.0,
            tau_non_trainable=None,
            sort_variables_by_name=False)

    # 'module_paths' lists one module path per line, relative to
    # FLAGS.data_path. Blank lines (e.g. a trailing empty line) are skipped;
    # otherwise they would turn into the bogus pair
    # ('<data_path>/.bc', '<data_path>/.cmd').
    with open(os.path.join(FLAGS.data_path, 'module_paths'), 'r') as f:
        module_paths = [
            os.path.join(FLAGS.data_path, name.rstrip('\n'))
            for name in f
            if name.strip()
        ]
    file_paths = [(path + '.bc', path + '.cmd') for path in module_paths]

    runner = inlining_runner.InliningRunner(
        clang_path=FLAGS.clang_path, llvm_size_path=FLAGS.llvm_size_path)

    sequence_example_iterator_fn = (
        data_reader.create_sequence_example_iterator_fn(
            agent_name=agent_name,
            config=config.CONFIG,
            batch_size=batch_size,
            train_sequence_length=train_sequence_length))

    data_collector = local_data_collector.LocalDataCollector(
        file_paths=file_paths,
        num_workers=FLAGS.num_workers,
        num_modules=FLAGS.num_modules,
        runner=runner.collect_data,
        parser=sequence_example_iterator_fn)

    for policy_iteration in range(num_policy_iterations):
        # Save the current policies so the collector can load them from disk.
        policy_path = os.path.join(root_dir, 'policy', str(policy_iteration))
        saver.save(policy_path)

        dataset_iter = data_collector.collect_data(
            policy_path=os.path.join(policy_path, deploy_policy_name))
        llvm_trainer.train(dataset_iter, num_iterations)

        # Tell the collector this round's dataset has been used.
        data_collector.on_dataset_consumed(dataset_iter)

    # Save final policy.
    saver.save(root_dir)