def testLoad(self):
  """Checks policy_loader.load with and without an explicit checkpoint.

  The test expects loading the bare saved model to yield train step 0 with
  first variable 0, and loading it together with the checkpoint returned by
  _createModelsOnDisk to yield train step 1 with first variable 10.
  """
  saved_dir, step_one_ckpt = self._createModelsOnDisk()

  def check(policy, expected_step, expected_value):
    # Step is asserted before the variables are read, mirroring the order
    # in which the expectations are stated above.
    self.assertEqual(expected_step, policy.get_train_step())
    self.assertEqual(expected_value, policy.variables()[0].numpy())

  # Saved model alone: initial weights.
  check(policy_loader.load(saved_dir), 0, 0)
  # Saved model + checkpoint: weights restored from the checkpoint.
  check(policy_loader.load(saved_dir, step_one_ckpt), 1, 10)
def testMaterialize(self):
  """Checks that materializing a saved model bakes in checkpoint weights.

  The saved model and checkpoint from _createModelsOnDisk are combined by
  policy_loader.materialize_saved_model into a new directory. That directory
  is then loaded without any checkpoint, and the test expects the step-1
  state (train step 1, first variable 10).
  """
  source_dir, step_one_ckpt = self._createModelsOnDisk()
  output_dir = os.path.join(self.root_dir, 'material/001')
  policy_loader.materialize_saved_model(source_dir, step_one_ckpt, output_dir)

  # No checkpoint is passed here: the weights must come from output_dir itself.
  materialized = policy_loader.load(output_dir)
  self.assertEqual(1, materialized.get_train_step())
  self.assertEqual(10, materialized.variables()[0].numpy())
def _read_module_file_paths(data_path):
  """Returns (`.bc`, `.cmd`) path pairs for each module listed in data_path.

  Reads `<data_path>/module_paths`, one module name per line, and joins each
  name onto data_path. Surrounding whitespace is stripped and blank lines are
  skipped. Otherwise an empty line (e.g. an extra trailing newline) would
  become `os.path.join(data_path, '')` and yield a bogus `<data_path>/.bc`
  entry.
  """
  with open(os.path.join(data_path, 'module_paths'), 'r') as f:
    module_paths = [
        os.path.join(data_path, name)
        for name in (line.strip() for line in f)
        if name
    ]
  return [(path + '.bc', path + '.cmd') for path in module_paths]


def train_eval(agent_name='ppo',
               warmstart_policy_dir=None,
               num_policy_iterations=0,
               num_iterations=100,
               batch_size=64,
               train_sequence_length=1,
               deploy_policy_name='saved_policy'):
  """Train for LLVM inliner.

  Runs `num_policy_iterations` rounds of: save the current policies under
  `<root_dir>/policy/<i>`, collect data with the deployed policy, then train on
  it. After all rounds, the final policies are saved to root_dir.

  Reads FLAGS.root_dir, FLAGS.data_path, FLAGS.clang_path,
  FLAGS.llvm_size_path, FLAGS.num_workers and FLAGS.num_modules.

  Args:
    agent_name: agent type passed to agent_creators.create_agent.
    warmstart_policy_dir: optional directory of a saved policy. If set, it is
      loaded with policy_loader.load and copied into the agent's policy via
      `update(..., tau=1.0)` before training.
    num_policy_iterations: number of collect/train rounds. With the default
      of 0 no training happens and only the final save runs.
    num_iterations: passed to the trainer's `train` call in each round.
    batch_size: forwarded to the sequence-example iterator factory.
    train_sequence_length: forwarded to the sequence-example iterator factory.
    deploy_policy_name: subdirectory of each saved policy dir whose policy the
      data collector runs.
  """
  root_dir = FLAGS.root_dir

  # Initialize trainer and policy saver.
  time_step_spec, action_spec = config.create_signature_specs(config.CONFIG)
  tf_agent = agent_creators.create_agent(agent_name, time_step_spec, action_spec)
  llvm_trainer = trainer.Trainer(root_dir=root_dir, agent=tf_agent)
  # NOTE(review): deploy_policy_name presumably has to be one of these keys
  # (the saver appears to write one subdirectory per key) -- TODO confirm.
  policy_dict = {
      'saved_policy': tf_agent.policy,
      'saved_collect_policy': tf_agent.collect_policy,
  }
  saver = policy_saver.PolicySaver(policy_dict=policy_dict)

  if warmstart_policy_dir:
    warmstart_policy = policy_loader.load(warmstart_policy_dir)
    tf_agent.policy.update(
        policy=warmstart_policy,
        tau=1.0,
        tau_non_trainable=None,
        sort_variables_by_name=False)

  file_paths = _read_module_file_paths(FLAGS.data_path)

  runner = inlining_runner.InliningRunner(
      clang_path=FLAGS.clang_path, llvm_size_path=FLAGS.llvm_size_path)

  sequence_example_iterator_fn = (
      data_reader.create_sequence_example_iterator_fn(
          agent_name=agent_name,
          config=config.CONFIG,
          batch_size=batch_size,
          train_sequence_length=train_sequence_length))

  data_collector = local_data_collector.LocalDataCollector(
      file_paths=file_paths,
      num_workers=FLAGS.num_workers,
      num_modules=FLAGS.num_modules,
      runner=runner.collect_data,
      parser=sequence_example_iterator_fn)

  for policy_iteration in range(num_policy_iterations):
    # Save the current policies so the collector can deploy them.
    policy_path = os.path.join(root_dir, 'policy', str(policy_iteration))
    saver.save(policy_path)

    dataset_iter = data_collector.collect_data(
        policy_path=os.path.join(policy_path, deploy_policy_name))
    llvm_trainer.train(dataset_iter, num_iterations)

    # Let the collector know this round's data has been fully used.
    data_collector.on_dataset_consumed(dataset_iter)

  # Save final policy.
  saver.save(root_dir)