Example #1
def main(argv=()):
    with errors.clean_commandline_error_exit():
        if len(argv) > 1:
            errors.log_and_raise(
                'Command line parsing failure: call_variants does not accept '
                'positional arguments but some are present on the command line: '
                '"{}".'.format(str(argv)), errors.CommandLineError)
        del argv  # Unused.
        proto_utils.uses_fast_cpp_protos_or_die()

        logging_level.set_from_flag()

        if FLAGS.use_tpu:
            master = tf_utils.resolve_master(FLAGS.master, FLAGS.tpu_name,
                                             FLAGS.tpu_zone, FLAGS.gcp_project)
        else:
            master = ''

        model = modeling.get_model(FLAGS.model_name)
        call_variants(
            examples_filename=FLAGS.examples,
            checkpoint_path=FLAGS.checkpoint,
            model=model,
            execution_hardware=FLAGS.execution_hardware,
            output_file=FLAGS.outfile,
            max_batches=FLAGS.max_batches,
            batch_size=FLAGS.batch_size,
            master=master,
            use_tpu=FLAGS.use_tpu,
        )
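In the entry points on this page, FLAGS comes from absl and argv is whatever app.run() hands to main() after flag parsing (the program name plus any leftover positional arguments, which is why Example #1 checks len(argv) > 1). The snippet below is a minimal, self-contained sketch of that wiring; the flag definitions are illustrative stand-ins and not the flags the real call_variants tool declares.

from absl import app
from absl import flags

FLAGS = flags.FLAGS

# Illustrative flag definitions; the real tool defines many more.
flags.DEFINE_string('examples', None, 'Path to the input examples.')
flags.DEFINE_string('checkpoint', None, 'Path to the model checkpoint.')
flags.DEFINE_string('outfile', None, 'Path for the output file.')
flags.DEFINE_boolean('use_tpu', False, 'Whether to run on a TPU.')


def main(argv=()):
    # argv carries any positional arguments left after absl parses the flags;
    # the example above rejects them because the tool is purely flag-driven.
    del argv
    print(FLAGS.examples, FLAGS.checkpoint, FLAGS.outfile, FLAGS.use_tpu)


if __name__ == '__main__':
    app.run(main)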
Example #2
def main(_):
    proto_utils.uses_fast_cpp_protos_or_die()

    if not FLAGS.dataset_config_pbtxt:
        logging.error('Need to specify --dataset_config_pbtxt')
    logging_level.set_from_flag()

    if FLAGS.kmp_blocktime:
        os.environ['KMP_BLOCKTIME'] = FLAGS.kmp_blocktime
        logging.info('Set KMP_BLOCKTIME to %s', os.environ['KMP_BLOCKTIME'])

    master = tf_utils.resolve_master(
        FLAGS.master, FLAGS.tpu_name, FLAGS.tpu_zone,
        FLAGS.gcp_project) if FLAGS.use_tpu else ''
    eval_loop(
        master=master,
        dataset_config_pbtxt=FLAGS.dataset_config_pbtxt,
        checkpoint_dir=FLAGS.checkpoint_dir,
        model_name=FLAGS.model_name,
        batch_size=FLAGS.batch_size,
        max_examples=FLAGS.max_examples,
        eval_name=FLAGS.eval_name,
        max_evaluations=FLAGS.max_evaluations,
        use_tpu=FLAGS.use_tpu,
    )
Example #3
def parse_and_run():
    """Parse TF_CONFIG to cluster_spec and call run().

  TF_CONFIG environment variable is available when running using
  gcloud either locally or on cloud. It has all the information required
  to create a ClusterSpec which is important for running distributed code.

  Raises:
    ValueError: If flags are invalid.
  """
    tf_config = os.environ.get('TF_CONFIG')
    logging.info('TF_CONFIG %s', tf_config)

    for name in ['master', 'task', 'ps_tasks']:
        if getattr(FLAGS, name) and tf_config:
            raise ValueError(
                'Either the flag --%s or the environment variable TF_CONFIG can be'
                ' set but not both.' % name)

    # redacted
    #
    # If TF_CONFIG is not available we are either running locally in Cloud
    # or distributed inside Google. On Cloud the default values of
    # FLAGS.master and FLAGS.task correspond to running training locally.
    # Inside Google they will be set as needed to configure local or distributed
    # training. Inside Google we don't need to explicitly set worker_device
    # in replica_device_setter because this will be set automatically based
    # on various flags.
    if not tf_config:
        device_fn = tf.train.replica_device_setter(FLAGS.ps_tasks)
        master = tf_utils.resolve_master(FLAGS.master, FLAGS.tpu_name,
                                         FLAGS.tpu_zone, FLAGS.gcp_project)
        return run(master,
                   FLAGS.task == 0,
                   device_fn=device_fn,
                   use_tpu=FLAGS.use_tpu)

    tf_config_json = json.loads(tf_config)

    cluster = tf_config_json.get('cluster')
    job_name = tf_config_json.get('task', {}).get('type')
    task_index = tf_config_json.get('task', {}).get('index')

    # If cluster information is empty, run locally.
    if job_name is None or task_index is None:
        device_fn = tf.train.replica_device_setter(0)
        return run('', True, device_fn=device_fn, use_tpu=FLAGS.use_tpu)

    ps = cluster.get('ps', [])
    num_ps = len(ps)

    cluster_spec = tf.train.ClusterSpec(cluster)
    server = tf.train.Server(cluster_spec,
                             job_name=job_name,
                             task_index=task_index)

    if job_name == 'ps':
        server.join()
        return
    elif job_name in ['master', 'worker']:
        device_fn = tf.train.replica_device_setter(
            num_ps,
            worker_device='/job:%s/task:%d' % (job_name, task_index),
            cluster=cluster_spec)
        return run(server.target,
                   job_name == 'master',
                   device_fn=device_fn,
                   use_tpu=FLAGS.use_tpu)
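parse_and_run() expects TF_CONFIG to be a JSON object with a "cluster" map of job names to host lists and a "task" entry naming this process's role. The sketch below builds such a value and extracts the same fields with the same json.loads / .get() calls as Example #3; the host names, ports, and task index are purely illustrative.

import json
import os

# An illustrative TF_CONFIG value; real deployments set this for each process.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'master': ['10.0.0.1:2222'],
        'worker': ['10.0.0.2:2222', '10.0.0.3:2222'],
        'ps': ['10.0.0.4:2222'],
    },
    'task': {'type': 'worker', 'index': 1},
})

tf_config_json = json.loads(os.environ['TF_CONFIG'])
cluster = tf_config_json.get('cluster')                    # job name -> host list
job_name = tf_config_json.get('task', {}).get('type')      # 'worker'
task_index = tf_config_json.get('task', {}).get('index')   # 1
print(job_name, task_index, len(cluster.get('ps', [])))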