def testLoad(self):
     """Check `policy_loader.load` with and without a checkpoint argument.

     Loading the saved model alone is expected to report train step 0 and
     a first variable equal to 0; loading it together with the checkpoint
     is expected to report train step 1 and a first variable equal to 10.
     """
     model_dir, checkpoint_path = self._createModelsOnDisk()
     # Each case: (arguments for load, expected train step,
     #             expected value of the first variable).
     cases = (
         ((model_dir,), 0, 0),
         ((model_dir, checkpoint_path), 1, 10),
     )
     for load_args, want_step, want_value in cases:
         policy = policy_loader.load(*load_args)
         self.assertEqual(want_step, policy.get_train_step())
         self.assertEqual(want_value, policy.variables()[0].numpy())
 def testMaterialize(self):
     """Check that a materialized saved model loads without a checkpoint.

     The saved model and checkpoint are combined into a new directory via
     `policy_loader.materialize_saved_model`; loading that directory alone
     is expected to report train step 1 and a first variable equal to 10.
     """
     model_dir, checkpoint_path = self._createModelsOnDisk()
     output_dir = os.path.join(self.root_dir, 'material/001')
     policy_loader.materialize_saved_model(
         model_dir, checkpoint_path, output_dir)

     # No checkpoint argument here: only the materialized directory is read.
     restored = policy_loader.load(output_dir)
     train_step = restored.get_train_step()
     self.assertEqual(1, train_step)
     first_variable = restored.variables()[0].numpy()
     self.assertEqual(10, first_variable)
예제 #3
0
def train_eval(agent_name='ppo',
               warmstart_policy_dir=None,
               num_policy_iterations=0,
               num_iterations=100,
               batch_size=64,
               train_sequence_length=1,
               deploy_policy_name='saved_policy'):
    """Run the save / collect / train loop for the LLVM inliner and save it.

    Each policy iteration saves the current policies under
    `<root_dir>/policy/<iteration>`, asks the data collector for a dataset
    using the `deploy_policy_name` sub-directory of that save, trains on the
    dataset, and then tells the collector the dataset has been consumed.
    After the loop the policies are saved once more, directly under
    `root_dir`.

    Reads `FLAGS.root_dir`, `FLAGS.data_path`, `FLAGS.clang_path`,
    `FLAGS.llvm_size_path`, `FLAGS.num_workers` and `FLAGS.num_modules`.

    Args:
      agent_name: Forwarded to `agent_creators.create_agent` and to the
        sequence-example iterator factory.
      warmstart_policy_dir: If truthy, a policy is loaded from this directory
        and used to update the agent's policy (with `tau=1.0`) before
        training starts.
      num_policy_iterations: Number of save / collect / train rounds. With
        the default of 0 the loop body never runs and only the final save
        happens.
      num_iterations: Passed to the trainer's `train` call in every round.
      batch_size: Forwarded to the sequence-example iterator factory.
      train_sequence_length: Forwarded to the sequence-example iterator
        factory.
      deploy_policy_name: Name of the sub-directory, inside each saved policy
        directory, that is handed to the data collector. The default matches
        the 'saved_policy' key given to the policy saver.
    """
    output_dir = FLAGS.root_dir

    # Build the agent, its trainer, and a saver covering both policies.
    time_step_spec, action_spec = config.create_signature_specs(config.CONFIG)
    agent = agent_creators.create_agent(
        agent_name, time_step_spec, action_spec)
    inliner_trainer = trainer.Trainer(root_dir=output_dir, agent=agent)
    saver = policy_saver.PolicySaver(policy_dict={
        'saved_policy': agent.policy,
        'saved_collect_policy': agent.collect_policy,
    })

    # Optionally seed the agent's policy from a previously saved one.
    if warmstart_policy_dir:
        initial_policy = policy_loader.load(warmstart_policy_dir)
        agent.policy.update(
            policy=initial_policy,
            tau=1.0,
            tau_non_trainable=None,
            sort_variables_by_name=False)

    # 'module_paths' holds one entry per line; each entry, joined onto the
    # data directory, is the common stem of a '.bc' / '.cmd' file pair.
    data_dir = FLAGS.data_path
    with open(os.path.join(data_dir, 'module_paths'), 'r') as listing:
        module_stems = [
            os.path.join(data_dir, line.rstrip('\n')) for line in listing
        ]
    file_paths = [(stem + '.bc', stem + '.cmd') for stem in module_stems]

    inlining = inlining_runner.InliningRunner(
        clang_path=FLAGS.clang_path, llvm_size_path=FLAGS.llvm_size_path)

    example_iterator_fn = data_reader.create_sequence_example_iterator_fn(
        agent_name=agent_name,
        config=config.CONFIG,
        batch_size=batch_size,
        train_sequence_length=train_sequence_length)

    collector = local_data_collector.LocalDataCollector(
        file_paths=file_paths,
        num_workers=FLAGS.num_workers,
        num_modules=FLAGS.num_modules,
        runner=inlining.collect_data,
        parser=example_iterator_fn)

    for iteration in range(num_policy_iterations):
        # Save first so the collector can read the policy from disk.
        iteration_dir = os.path.join(output_dir, 'policy', str(iteration))
        saver.save(iteration_dir)

        deployed_policy_dir = os.path.join(iteration_dir, deploy_policy_name)
        dataset_iter = collector.collect_data(policy_path=deployed_policy_dir)
        inliner_trainer.train(dataset_iter, num_iterations)

        collector.on_dataset_consumed(dataset_iter)

    # Save final policy.
    saver.save(output_dir)