Code example #1
# Imports for this snippet; the compiler_opt.rl module paths are
# assumed from the ml-compiler-opt project layout:
import os

from absl import flags

from compiler_opt.rl import agent_creators
from compiler_opt.rl import config
from compiler_opt.rl import data_reader
from compiler_opt.rl import policy_saver
from compiler_opt.rl import trainer

FLAGS = flags.FLAGS

def train_eval(agent_name='behavioral_cloning',
               num_iterations=100,
               batch_size=64,
               train_sequence_length=1):
    """Train for LLVM inliner."""
    root_dir = os.path.expanduser(FLAGS.root_dir)
    root_dir = os.path.normpath(root_dir)

    # Initialize trainer and policy saver.
    time_step_spec, action_spec = config.create_signature_specs(config.CONFIG)
    tf_agent = agent_creators.create_agent(agent_name, time_step_spec,
                                           action_spec)
    llvm_trainer = trainer.Trainer(root_dir=root_dir, agent=tf_agent)
    policy_dict = {
        'saved_policy': tf_agent.policy,
        'saved_collect_policy': tf_agent.collect_policy,
    }
    saver = policy_saver.PolicySaver(policy_dict=policy_dict)

    tfrecord_iterator_fn = data_reader.create_tfrecord_iterator_fn(
        agent_name=agent_name,
        config=config.CONFIG,
        batch_size=batch_size,
        train_sequence_length=train_sequence_length)

    # Train.
    dataset_iter = tfrecord_iterator_fn(FLAGS.data_path)
    llvm_trainer.train(dataset_iter, num_iterations)

    # Save final policy.
    saver.save(root_dir)
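
In context, a train_eval like this is typically invoked from an absl entrypoint that parses gin configuration first (the tests below bind agent parameters through gin). The wiring sketched here is illustrative only: the flag definitions and gin handling are assumptions, not taken from the snippet.

# Hypothetical entrypoint for the train_eval snippet above; the flag
# definitions and gin handling are assumptions for illustration.
from absl import app
from absl import flags
import gin

flags.DEFINE_string('root_dir', None,
                    'Directory for checkpoints and saved policies.')
flags.DEFINE_string('data_path', None,
                    'Path to the training TFRecord data.')
flags.DEFINE_multi_string('gin_files', [],
                          'Paths to gin configuration files.')
FLAGS = flags.FLAGS


def main(_):
    # Apply gin bindings (e.g. agent and optimizer settings) before training.
    gin.parse_config_files_and_bindings(FLAGS.gin_files, bindings=None)
    train_eval()


if __name__ == '__main__':
    app.run(main)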
Code example #2
# Test method from a tf.test.TestCase subclass; setUp() is assumed to
# build self._time_step_spec and self._action_spec, with gin,
# tensorflow (tf), agent_creators, and the tf_agents network/agent
# modules imported at module level.
def test_create_dqn_agent(self):
    gin.bind_parameter('create_agent.policy_network', q_network.QNetwork)
    gin.bind_parameter('DqnAgent.optimizer',
                       tf.compat.v1.train.AdamOptimizer())
    tf_agent = agent_creators.create_agent(
        agent_name='dqn',
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec)
    self.assertIsInstance(tf_agent, dqn_agent.DqnAgent)
Code example #3
def test_create_ppo_agent(self):
    gin.bind_parameter('create_agent.policy_network',
                       actor_distribution_network.ActorDistributionNetwork)
    gin.bind_parameter('PPOAgent.optimizer',
                       tf.compat.v1.train.AdamOptimizer())
    tf_agent = agent_creators.create_agent(
        agent_name='ppo',
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec)
    self.assertIsInstance(tf_agent, ppo_agent.PPOAgent)
Code example #4
def test_create_behavioral_cloning_agent(self):
    gin.bind_parameter('create_agent.policy_network', q_network.QNetwork)
    gin.bind_parameter('BehavioralCloningAgent.optimizer',
                       tf.compat.v1.train.AdamOptimizer())
    tf_agent = agent_creators.create_agent(
        agent_name='behavioral_cloning',
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec)
    self.assertIsInstance(tf_agent,
                          behavioral_cloning_agent.BehavioralCloningAgent)
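
All three tests exercise the same pattern: create_agent is a gin-configurable factory, the policy network class is injected through the 'create_agent.policy_network' binding, and each agent's optimizer is bound on the corresponding (gin-configurable) tf_agents agent class. The factory below is a minimal sketch reconstructed from those test expectations; it is an assumption, not the actual agent_creators implementation.

# Minimal sketch of a gin-configurable agent factory consistent with
# the tests above; an assumption, not the real agent_creators code.
import gin
from tf_agents.agents.behavioral_cloning import behavioral_cloning_agent
from tf_agents.agents.dqn import dqn_agent
from tf_agents.agents.ppo import ppo_agent
from tf_agents.networks import value_network


@gin.configurable
def create_agent(agent_name, time_step_spec, action_spec, policy_network):
    # policy_network arrives via the 'create_agent.policy_network' binding.
    network = policy_network(time_step_spec.observation, action_spec)
    if agent_name == 'dqn':
        # The optimizer comes from the 'DqnAgent.optimizer' gin binding.
        return dqn_agent.DqnAgent(
            time_step_spec, action_spec, q_network=network)
    if agent_name == 'ppo':
        # PPO also needs a critic; a plain ValueNetwork over observations.
        critic = value_network.ValueNetwork(time_step_spec.observation)
        return ppo_agent.PPOAgent(
            time_step_spec, action_spec, actor_net=network, value_net=critic)
    if agent_name == 'behavioral_cloning':
        return behavioral_cloning_agent.BehavioralCloningAgent(
            time_step_spec, action_spec, cloning_network=network)
    raise ValueError(f'Unknown agent_name: {agent_name}')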
Code example #5
# Imports for this snippet; the compiler_opt.rl module paths are
# assumed from the ml-compiler-opt project layout and may differ:
import os

from absl import flags

from compiler_opt.rl import agent_creators
from compiler_opt.rl import config
from compiler_opt.rl import data_reader
from compiler_opt.rl import local_data_collector
from compiler_opt.rl import policy_loader
from compiler_opt.rl import policy_saver
from compiler_opt.rl import trainer
from compiler_opt.rl.inlining import inlining_runner

FLAGS = flags.FLAGS

def train_eval(agent_name='ppo',
               warmstart_policy_dir=None,
               num_policy_iterations=0,
               num_iterations=100,
               batch_size=64,
               train_sequence_length=1,
               deploy_policy_name='saved_policy'):
    """Train for LLVM inliner."""
    root_dir = FLAGS.root_dir

    # Initialize trainer and policy saver.
    time_step_spec, action_spec = config.create_signature_specs(config.CONFIG)
    tf_agent = agent_creators.create_agent(agent_name, time_step_spec,
                                           action_spec)
    llvm_trainer = trainer.Trainer(root_dir=root_dir, agent=tf_agent)
    policy_dict = {
        'saved_policy': tf_agent.policy,
        'saved_collect_policy': tf_agent.collect_policy,
    }
    saver = policy_saver.PolicySaver(policy_dict=policy_dict)

    # Warm-start the agent's policy from a previously trained policy.
    if warmstart_policy_dir:
        warmstart_policy = policy_loader.load(warmstart_policy_dir)
        tf_agent.policy.update(policy=warmstart_policy,
                               tau=1.0,
                               tau_non_trainable=None,
                               sort_variables_by_name=False)

    # Build (bitcode, compile-command) file pairs for every module listed
    # in the module_paths index file.
    with open(os.path.join(FLAGS.data_path, 'module_paths'), 'r') as f:
        module_paths = [
            os.path.join(FLAGS.data_path, name.rstrip('\n')) for name in f
        ]
        file_paths = [(path + '.bc', path + '.cmd') for path in module_paths]

    # The runner compiles each module with clang under the deployed policy
    # and measures native code size with llvm-size.
    runner = inlining_runner.InliningRunner(
        clang_path=FLAGS.clang_path, llvm_size_path=FLAGS.llvm_size_path)

    sequence_example_iterator_fn = (
        data_reader.create_sequence_example_iterator_fn(
            agent_name=agent_name,
            config=config.CONFIG,
            batch_size=batch_size,
            train_sequence_length=train_sequence_length))

    # The data collector runs the current policy over sampled modules in
    # parallel workers and parses the resulting trajectories for training.
    data_collector = local_data_collector.LocalDataCollector(
        file_paths=file_paths,
        num_workers=FLAGS.num_workers,
        num_modules=FLAGS.num_modules,
        runner=runner.collect_data,
        parser=sequence_example_iterator_fn)

    # Alternate between deploying the current policy, collecting fresh
    # trajectories with it, and training on the collected data.
    for policy_iteration in range(num_policy_iterations):
        policy_path = os.path.join(root_dir, 'policy', str(policy_iteration))
        saver.save(policy_path)

        dataset_iter = data_collector.collect_data(
            policy_path=os.path.join(policy_path, deploy_policy_name))
        llvm_trainer.train(dataset_iter, num_iterations)

        data_collector.on_dataset_consumed(dataset_iter)

    # Save final policy.
    saver.save(root_dir)
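
Both train_eval variants read their environment from absl flags. Beyond root_dir and data_path (sketched after code example #1), this variant also uses clang_path, llvm_size_path, num_workers, and num_modules. The definitions below are hypothetical: only the flag names come from the snippet; types, defaults, and help strings are assumptions.

# Hypothetical flag definitions implied by the FLAGS usages above; only
# the names are taken from the snippet, the rest is assumed.
from absl import flags

flags.DEFINE_string('clang_path', 'clang',
                    'Path to the clang binary used to compile modules.')
flags.DEFINE_string('llvm_size_path', 'llvm-size',
                    'Path to llvm-size, used to measure native code size.')
flags.DEFINE_integer('num_workers', 1,
                     'Number of parallel data-collection workers.')
flags.DEFINE_integer('num_modules', 100,
                     'Number of modules sampled per collection round.')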