def main(argv=()):
  """Entry point for call_variants: validate flags, then run calling.

  Args:
    argv: Command-line positional arguments. Any positional argument beyond
      the program name itself is an error for this tool.
  """
  with errors.clean_commandline_error_exit():
    # argv[0] is the program name, so anything beyond one entry means the
    # user passed positional arguments, which this tool does not accept.
    if len(argv) > 1:
      errors.log_and_raise(
          'Command line parsing failure: call_variants does not accept '
          'positional arguments but some are present on the command line: '
          '"{}".'.format(str(argv)), errors.CommandLineError)
    del argv  # Unused.

    proto_utils.uses_fast_cpp_protos_or_die()
    logging_level.set_from_flag()

    # Only resolve a TPU master when TPU execution was requested.
    master = (
        tf_utils.resolve_master(FLAGS.master, FLAGS.tpu_name, FLAGS.tpu_zone,
                                FLAGS.gcp_project) if FLAGS.use_tpu else '')

    model = modeling.get_model(FLAGS.model_name)
    call_variants(
        examples_filename=FLAGS.examples,
        checkpoint_path=FLAGS.checkpoint,
        model=model,
        execution_hardware=FLAGS.execution_hardware,
        output_file=FLAGS.outfile,
        max_batches=FLAGS.max_batches,
        batch_size=FLAGS.batch_size,
        master=master,
        use_tpu=FLAGS.use_tpu,
    )
def main(_):
  """Entry point for model evaluation: validate flags, then run eval_loop.

  Args:
    _: Unused positional arguments from app.run().

  Raises:
    ValueError: If --dataset_config_pbtxt is not provided.
  """
  proto_utils.uses_fast_cpp_protos_or_die()

  if not FLAGS.dataset_config_pbtxt:
    logging.error('Need to specify --dataset_config_pbtxt')
    # Previously this only logged and fell through, which let execution
    # continue with a missing config and crash later inside eval_loop with a
    # confusing error. Fail fast instead.
    raise ValueError('Need to specify --dataset_config_pbtxt')
  logging_level.set_from_flag()

  # KMP_BLOCKTIME tunes Intel OpenMP thread spin time; only set it when the
  # flag was explicitly provided.
  if FLAGS.kmp_blocktime:
    os.environ['KMP_BLOCKTIME'] = FLAGS.kmp_blocktime
    logging.info('Set KMP_BLOCKTIME to %s', os.environ['KMP_BLOCKTIME'])

  # Only resolve a TPU master when TPU execution was requested.
  master = tf_utils.resolve_master(
      FLAGS.master, FLAGS.tpu_name, FLAGS.tpu_zone,
      FLAGS.gcp_project) if FLAGS.use_tpu else ''
  eval_loop(
      master=master,
      dataset_config_pbtxt=FLAGS.dataset_config_pbtxt,
      checkpoint_dir=FLAGS.checkpoint_dir,
      model_name=FLAGS.model_name,
      batch_size=FLAGS.batch_size,
      max_examples=FLAGS.max_examples,
      eval_name=FLAGS.eval_name,
      max_evaluations=FLAGS.max_evaluations,
      use_tpu=FLAGS.use_tpu,
  )
def parse_and_run():
  """Parse TF_CONFIG to cluster_spec and call run().

  TF_CONFIG environment variable is available when running using
  gcloud either locally or on cloud. It has all the information required
  to create a ClusterSpec which is important for running distributed code.

  Raises:
    ValueError: If flags are invalid, or if TF_CONFIG names a job type other
      than 'ps', 'master', or 'worker'.
  """
  tf_config = os.environ.get('TF_CONFIG')
  logging.info('TF_CONFIG %s', tf_config)

  # These flags and TF_CONFIG both configure the cluster; allowing both would
  # make it ambiguous which one wins.
  for name in ['master', 'task', 'ps_tasks']:
    if getattr(FLAGS, name) and tf_config:
      raise ValueError(
          'Either the flag --%s or the environment variable TF_CONFIG can be'
          ' set but not both.' % name)

  # redacted
  #
  # If TF_CONFIG is not available we are either running locally in Cloud
  # or distributed inside Google. On Cloud the default values of
  # FLAGS.master and FLAGS.task correspond to running training locally.
  # Inside Google they will be set as needed to configure local or distributed
  # training. Inside Google we don't need to explicitly set worker_device
  # in replica_device_setter becaue this will be set automatically based
  # on various flags.
  if not tf_config:
    device_fn = tf.train.replica_device_setter(FLAGS.ps_tasks)
    master = tf_utils.resolve_master(FLAGS.master, FLAGS.tpu_name,
                                     FLAGS.tpu_zone, FLAGS.gcp_project)
    return run(
        master, FLAGS.task == 0, device_fn=device_fn, use_tpu=FLAGS.use_tpu)

  tf_config_json = json.loads(tf_config)

  cluster = tf_config_json.get('cluster')
  job_name = tf_config_json.get('task', {}).get('type')
  task_index = tf_config_json.get('task', {}).get('index')

  # If cluster information is empty run local
  if job_name is None or task_index is None:
    device_fn = tf.train.replica_device_setter(0)
    return run('', True, device_fn=device_fn, use_tpu=FLAGS.use_tpu)

  ps = cluster.get('ps', [])
  num_ps = len(ps)

  cluster_spec = tf.train.ClusterSpec(cluster)
  server = tf.train.Server(
      cluster_spec, job_name=job_name, task_index=task_index)

  if job_name == 'ps':
    # Parameter servers only serve variables; block forever.
    server.join()
    return
  elif job_name in ['master', 'worker']:
    device_fn = tf.train.replica_device_setter(
        num_ps,
        worker_device='/job:%s/task:%d' % (job_name, task_index),
        cluster=cluster_spec)
    return run(
        server.target,
        job_name == 'master',
        device_fn=device_fn,
        use_tpu=FLAGS.use_tpu)
  else:
    # Previously an unrecognized job type fell off the end of the function
    # and silently returned None; surface it as the configuration error that
    # the docstring already promises.
    raise ValueError('Unexpected job type %r in TF_CONFIG.' % job_name)