flags.DEFINE_integer('bigtable_node_count', 3, 'Number of nodes to create in the bigtable cluster.') flags.DEFINE_enum('bigtable_storage_type', 'ssd', ['ssd', 'hdd'], 'Storage class for the cluster') flags.DEFINE_string('google_bigtable_zone', 'us-central1-b', 'Bigtable zone.') flags.DEFINE_boolean('bigtable_replication_cluster', False, 'Whether to create a Bigtable replication cluster.') flags.DEFINE_string('bigtable_replication_cluster_zone', None, 'Zone in which to create a Bigtable replication cluster.') flags.DEFINE_boolean('bigtable_multicluster_routing', False, 'Whether to use multi-cluster routing.') flags.register_multi_flags_validator( ['bigtable_replication_cluster', 'bigtable_replication_cluster_zone'], _ValidateReplicationFlags, message='bigtable_replication_cluster_zone must ' 'be set if bigtable_replication_cluster is True.') flags.register_multi_flags_validator( ['bigtable_replication_cluster', 'bigtable_multicluster_routing'], _ValidateRoutingFlags, message='bigtable_replication_cluster must ' 'be set if bigtable_multicluster_routing is True.') class GcpBigtableInstance(resource.BaseResource): """Object representing a GCP Bigtable Instance. Attributes: name: Instance and cluster name. num_nodes: Number of nodes in the instance's cluster.
help='Directory in which to save the output.') flags.DEFINE_string('runner', 'direct', help='Beam runner (`direct` or `dataflow`).') flags.mark_flag_as_required('output_dir') def validate_project_id_flag(inputs): needs_project_id = bool('dataflow' in inputs['runner'].lower() or inputs['input_tce_table']) return bool(inputs['project_id']) if needs_project_id else True flags.register_multi_flags_validator( ['project_id', 'input_tce_table', 'runner'], validate_project_id_flag, message='--project_id must be set if running on Dataflow or ' 'reading from BigQuery.') def read_tces(input_tce_csv_file, input_tce_table=None, project_id=None): """Read, filter, and partition a table of Kepler KOIs. Args: input_tce_csv_file: CSV file containing the Q1-Q17 DR24 Kepler TCE table. input_tce_table: BQ table name containing the Q1-Q17 DR24 Kepler TCE table. project_id: GCP project ID. Required if `input_tce_table` is passed. Returns: A `dict` with keys ['train', 'val', 'test'], where the values are lists of single `dict` TCE records as belonging to that subset. """
'cleaned up before deleting the network. If firewall ' 'rules are added manually, PKB will not know about all of ' 'them. However, they must be deleted in order to ' 'successfully delete the PKB-created network.') flags.DEFINE_enum('bq_client_interface', 'CLI', ['CLI', 'JAVA', 'SIMBA_JDBC_1_2_4_1007'], 'The Runtime Interface used when interacting with BigQuery.') flags.DEFINE_string('gcp_preemptible_status_bucket', None, 'The GCS bucket to store the preemptible status when ' 'running on GCP.') flags.DEFINE_integer( 'gcp_provisioned_iops', 100000, 'Iops to provision for pd-extreme. Defaults to the gcloud ' 'default of 100000.') API_OVERRIDE = flags.DEFINE_string( 'gcp_cloud_redis_api_override', default='https://redis.googleapis.com/', help='Cloud redis API endpoint override. Defaults to prod.') def _ValidatePreemptFlags(flags_dict): if flags_dict['gce_preemptible_vms']: return bool(flags_dict['gcp_preemptible_status_bucket']) return True flags.register_multi_flags_validator( ['gce_preemptible_vms', 'gcp_preemptible_status_bucket'], _ValidatePreemptFlags, 'When gce_preemptible_vms is specified, ' 'gcp_preemptible_status_bucket must be specified.')
if FLAGS.config_file: with tf.io.gfile.GFile(FLAGS.config_file, "r") as reader: config = json.load(reader) else: config = json.loads(FLAGS.config) # # Save config to workdir if it's not yet exists if jax.process_index() == 0: config_file = os.path.join(FLAGS.model_dir, "config.json") with tf.io.gfile.GFile(config_file, "w") as writer: writer.write(json.dumps(config, indent=4)) config["model_dir"] = FLAGS.model_dir if FLAGS.learning_rate is not None: config["learning_rate"] = FLAGS.learning_rate if FLAGS.per_device_batch_size is not None: config["per_device_batch_size"] = FLAGS.per_device_batch_size if FLAGS.num_train_steps is not None: config["num_train_steps"] = FLAGS.num_train_steps if FLAGS.warmup_steps is not None: config["warmup_steps"] = FLAGS.warmup_steps train(ml_collections.ConfigDict(config)) if __name__ == "__main__": flags.register_multi_flags_validator(["config", "config_file"], validate_config_flags, "Either --config or --config_file needs " "to be set.") app.run(main)
flags.DEFINE_string('model_dir', None, 'The working directory of the model') # See www.moderndescartes.com/essays/shuffle_viz for discussion on sizing flags.DEFINE_integer('shuffle_buffer_size', 20000, 'Size of buffer used to shuffle train examples.') flags.DEFINE_bool('use_tpu', False, 'Whether to use TPU for training.') flags.DEFINE_string( 'tpu_name', None, 'The Cloud TPU to use for training. This should be either the name used' 'when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.') flags.register_multi_flags_validator( ['use_tpu', 'tpu_name'], lambda flags: bool(flags['use_tpu']) == bool(flags['tpu_name']), 'If use_tpu is set, tpu_name must also be set.') flags.register_multi_flags_validator( ['lr_boundaries', 'lr_rates'], lambda flags: len(flags['lr_boundaries']) == len(flags['lr_rates']) - 1, 'Number of learning rates must be exactly one greater than the number of boundaries' ) flags.DEFINE_integer( 'iterations_per_loop', 200, help=('Number of steps to run on TPU before outfeeding metrics to the CPU.' ' If the number of iterations in the loop would exceed the number of' ' train steps, the loop will exit before reaching' ' --iterations_per_loop. The larger this value is, the higher the'
flags.DEFINE_string('pretrained_model_folder', None,
                    'Model folder under ./trained_models if --eval is set')
flags.DEFINE_float('recall_inference_bias', None, 'Recall bias value')
flags.DEFINE_enum('token_embedding_dimension', None, ['100', '300'],
                  'Token embedding dimension size.')
flags.DEFINE_integer('threads_tf', 32, 'Num threads for any tf op.')
flags.DEFINE_integer(
    'threads_prediction', 100,
    'Num threads for eval prediction data partitioning '
    'into chunks for that amount of threads.')

flags.mark_flags_as_required(
    ['dataset_text_folder', 'output_folder', 'token_embedding_dimension'])
# With required=True exactly one of --train / --eval has to be given.
flags.mark_bool_flags_as_mutual_exclusive(['train', 'eval'], required=True)


def _train_mode_bias_is_unset(flag_values):
    """Return True unless --train is combined with a non-zero recall bias.

    Args:
      flag_values: dict mapping the validated flag names to their values.
    """
    if not flag_values['train']:
        return True
    return flag_values['recall_inference_bias'] in [None, 0.]


def _eval_mode_flags_are_set(flag_values):
    """Return True unless --eval is set while a validated flag is falsy.

    Args:
      flag_values: dict mapping the validated flag names to their values.
    """
    return not flag_values['eval'] or all(flag_values.values())


# Named checkers are used instead of lambdas so the parameter does not
# shadow the imported `flags` module.
flags.register_multi_flags_validator(
    flag_names=['train', 'recall_inference_bias'],
    multi_flags_checker=_train_mode_bias_is_unset,
    message='In train mode, recall_inference_bias must be unset or zero.')
flags.register_multi_flags_validator(
    flag_names=['eval', 'pretrained_model_folder'],
    multi_flags_checker=_eval_mode_flags_are_set,
    message='In eval mode, all these flags must be set.')
def must_add_imu_sensor(): """ Returns true if the IMU sensor must be added. We don't add all sensors by default because they slow down the simulation """ return (FLAGS.imu or FLAGS.evaluation) # Flag validators. flags.register_multi_flags_validator( [ 'obstacle_detection', 'obstacle_detection_model_paths', 'obstacle_detection_model_names' ], lambda flags_dict: (not flags_dict['obstacle_detection'] or (flags_dict['obstacle_detection'] and (len(flags_dict['obstacle_detection_model_paths']) == len(flags_dict['obstacle_detection_model_names'])))), message='--obstacle_detection_model_paths and ' '--obstacle_detection_model_names must have the same length') def prediction_validator(flags_dict): if flags_dict['prediction']: return (flags_dict['obstacle_tracking'] or flags_dict['perfect_obstacle_tracking']) return True flags.register_multi_flags_validator(
flags.DEFINE_bool('record_lidar', False, 'True to record lidar') flags.DEFINE_bool('record_rgb_camera', False, 'True to record RGB camera') flags.DEFINE_bool( 'record_ground_truth', False, 'True to carla data (e.g., vehicle position, traffic lights)') # Other flags flags.DEFINE_integer('num_cameras', 5, 'Number of cameras.') # Flag validators. flags.register_validator('framework', lambda value: value == 'ros' or value == 'ray', message='--framework must be: ros | ray') flags.register_multi_flags_validator( ['replay', 'evaluate_obj_detection'], lambda flags_dict: not (flags_dict['replay'] and flags_dict[ 'evaluate_obj_detection']), message='--evaluate_obj_detection cannot be set when --replay is set') flags.register_multi_flags_validator( ['replay', 'fusion'], lambda flags_dict: not (flags_dict['replay'] and flags_dict['fusion']), message='--fusion cannot be set when --replay is set') # flags.register_multi_flags_validator( # ['ground_agent_operator', 'obj_detection', 'traffic_light_det', 'segmentation_drn', 'segmentation_dla'], # lambda flags_dict: (flags_dict['ground_agent_operator'] or # (flags_dict['obj_detection'] and # flags_dict['traffic_light_det'] and # (flags_dict['segmentation_drn'] or flags_dict['segmentation_dla']))), # message='ERDOS agent requires obj detection, segmentation and traffic light detection') flags.register_multi_flags_validator( [
train_steps=FLAGS.train_steps): # continuous_train_and_eval() yields evaluation metrics after each # checkpoint. It also saves and logs them, so we don't do anything here. pass if __name__ == "__main__": tf.logging.set_verbosity(tf.logging.INFO) flags.mark_flags_as_required(["dataset", "model_dir", "schedule"]) def _validate_schedule(flag_values): """Validates the --schedule flag and the flags it interacts with.""" schedule = flag_values["schedule"] save_checkpoints_steps = flag_values["save_checkpoints_steps"] save_checkpoints_secs = flag_values["save_checkpoints_secs"] if schedule in ["train", "train_and_eval"]: if not (save_checkpoints_steps or save_checkpoints_secs): raise flags.ValidationError( "--schedule='%s' requires --save_checkpoints_steps or " "--save_checkpoints_secs." % schedule) return True flags.register_multi_flags_validator( ["schedule", "save_checkpoints_steps", "save_checkpoints_secs"], _validate_schedule) tf.app.run()
flags.DEFINE_string('dataset_text_folder', 'i2b2-2014-paper', 'dataset to eval on') flags.DEFINE_enum('edim', None, ['100', '300'], 'Token embedding dimension') FLAGS = flags.FLAGS flags.mark_flags_as_required(['edim']) # Must specify only one verb. flags.mark_bool_flags_as_mutual_exclusive(['eval', 'eval-view-test-bias'], required=True) flags.register_multi_flags_validator( flag_names=['eval', 'rbias', 'pretrained_model_folder'], multi_flags_checker=lambda flags: not flags['eval'] or all(flags.values()), message='In eval mode, all these flags must be set.') def start_and_wait_for_jobs(processes, descriptions): assert len(processes) == len(descriptions) jobs = [] processes_descriptions = list(zip(processes, descriptions)) for i, (process, description) in enumerate(processes_descriptions): logging.info('Starting %s', description) jobs.append(process.run_bg()) if i != len(processes_descriptions) - 1: time.sleep(1) for running_jobs in more_itertools.repeatfunc(sum, None, (not j.ready() for j in jobs)): if running_jobs == 0:
# Optional
flags.DEFINE_boolean('force', False,
                     'Force update, regardless of last update time')

# Must be present if CM
flags.DEFINE_boolean('cm_profiles', False, 'List available CM profiles.')
flags.DEFINE_integer('profile', None,
                     'Campaign Manager profile id. Only needed for CM.')
flags.DEFINE_boolean('cm_superuser', False,
                     'User is an _internal_ CM Superuser.')
flags.DEFINE_integer('account', None,
                     'CM account id. Required for CM Superusers.')


def _validate_superuser_account(value):
    """Check that --account is given exactly when --cm_superuser is set.

    Args:
      value: dict mapping 'account' and 'cm_superuser' to their flag values.

    Returns:
      True when both flags are truthy or both are falsy, False otherwise.
    """
    return bool(value['account']) == bool(value['cm_superuser'])


# The message names the flag as it is actually defined (--account).
flags.register_multi_flags_validator(
    ['account', 'cm_superuser'],
    _validate_superuser_account,
    message='--account must be set for a superuser, and not set for normal '
    'users.')

#flags.register_multi_flags_validator(['cm_id', 'profile'],
#                       lambda value: (value['cm_id'] and not value['profile']) or (not value['cm_id'] and value['profile']),
#                       message='profile must be set for a CM report to be specified')


# Stub main()
def main(unused_argv):
    """Create a report runner from the parsed flags.

    NOTE(review): only the DV360 branch is visible in this excerpt.
    """
    if FLAGS.dv360_id:
        runner = DBMReportRunner(dbm_id=FLAGS.dv360_id,
                                 email=FLAGS.email,
                                 project=FLAGS.project)
FLAGS.zinbwave_dims, FLAGS.zinbwave_epsilon, FLAGS.zinbwave_keep_variance, FLAGS.zinbwave_gene_covariate, metrics.silhouette, metrics.ami, metrics.ari, metrics.kmeans_silhouette, adata.n_obs, FLAGS.tissue, n_clusters, ]) if FLAGS.output_h5ad: adata.write(FLAGS.output_h5ad) if __name__ == '__main__': flags.mark_flags_as_mutual_exclusive(['input_loom', 'input_csvs']) flags.mark_flag_as_required('output_csv') flags.mark_flag_as_required('reduced_dim') flags.mark_flag_as_required('tissue') flags.mark_flag_as_required('source') flags.register_multi_flags_validator( flag_names=( ['source'] + list(itertools.chain.from_iterable(_SOURCE_TO_FLAGS.values()))), multi_flags_checker=check_flags_combination, message='Source and other flags are not compatible.') app.run(main)
'https://cloud.google.com/spanner/docs/cpu-utilization#recommended-max.') _CPU_TARGET_HIGH_PRIORITY_UPPER_BOUND = flags.DEFINE_float( 'cloud_spanner_ycsb_cpu_optimization_target_max', 0.75, 'Maximum target CPU utilization after which the benchmark will throw an ' 'exception. This is needed so that in CPU-optimized mode, the increase in ' 'QPS does not overshoot the target CPU percentage by too much.') def _ValidateCpuTargetFlags(flags_dict): return (flags_dict['cloud_spanner_ycsb_cpu_optimization_target_max'] > flags_dict['cloud_spanner_ycsb_cpu_optimization_target']) flags.register_multi_flags_validator( [ 'cloud_spanner_ycsb_cpu_optimization_target', 'cloud_spanner_ycsb_cpu_optimization_target_max' ], _ValidateCpuTargetFlags, 'CPU optimization max target must be greater than target.') _CPU_OPTIMIZATION_INCREMENT_MINUTES = flags.DEFINE_integer( 'cloud_spanner_ycsb_cpu_optimization_workload_mins', 30, 'Length of time to run YCSB until incrementing QPS.') _CPU_OPTIMIZATION_MEASUREMENT_MINUTES = flags.DEFINE_integer( 'cloud_spanner_ycsb_cpu_optimization_measurement_mins', 5, 'Length of time to measure average CPU at the end of a test. For example, ' 'the default 5 means that only the last 5 minutes of the test will be ' 'used for representative CPU utilization.') _STARTING_QPS = flags.DEFINE_integer( 'cloud_spanner_ycsb_min_target', None, 'Starting QPS to set as YCSB target. Defaults to a value which uses the ' 'published throughput expectations for each node, see READ/WRITE caps per '
from absl import app, flags
import dual_net_prim

flags.DEFINE_string('model_path', None, 'Path to model to freeze')
flags.mark_flag_as_required('model_path')

flags.DEFINE_boolean(
    'use_trt', False, 'True to write a GraphDef that uses the TRT runtime')
flags.DEFINE_integer('trt_max_batch_size', None, 'Maximum TRT batch size')
flags.DEFINE_string('trt_precision', 'fp32',
                    'Precision for TRT runtime: fp16, fp32 or int8')


def _trt_batch_size_given(flag_values):
    """Pass unless --use_trt is set without a truthy --trt_max_batch_size."""
    if flag_values['use_trt']:
        return flag_values['trt_max_batch_size']
    return True


flags.register_multi_flags_validator(
    ['use_trt', 'trt_max_batch_size'],
    _trt_batch_size_given,
    'trt_max_batch_size must be set if use_trt is true')

FLAGS = flags.FLAGS


def main(unused_argv):
    """Freeze a model to a GraphDef proto."""
    dual_net_prim.freeze_graph(
        FLAGS.model_path,
        FLAGS.use_trt,
        FLAGS.trt_max_batch_size,
        FLAGS.trt_precision,
    )


if __name__ == "__main__":
    app.run(main)
flags.DEFINE_float('gnss_bias_lat', 0.0,
                   'Sets the bias on the latitude of the GNSS sensor.')
flags.DEFINE_float('gnss_bias_lon', 0.0,
                   'Sets the bias on the longitude of the GNSS sensor.')


def sensor_frequency_validator(flags_dict):
    """Check that no validated frequency flag exceeds --simulator_fps."""
    # Compared one after another in this order; stops at the first violation.
    bounded_by_fps = (
        'simulator_camera_frequency',
        'simulator_lidar_frequency',
        'simulator_imu_frequency',
        'simulator_localization_frequency',
        'simulator_control_frequency',
    )
    return all(flags_dict[flag_name] <= flags_dict['simulator_fps']
               for flag_name in bounded_by_fps)


flags.register_multi_flags_validator(
    [
        'simulator_fps',
        'simulator_camera_frequency',
        'simulator_imu_frequency',
        'simulator_lidar_frequency',
        'simulator_localization_frequency',
        'simulator_control_frequency',
    ],
    sensor_frequency_validator,
    message='Sensor frequencies cannot be greater than --simulator_fps')
with tf.io.gfile.GFile(FLAGS.config_file, 'r') as reader: config = json.load(reader) else: config = json.loads(FLAGS.config) # # Save config to workdir if it's not yet exists if jax.process_index() == 0: config_file = os.path.join(FLAGS.output_dir, 'config.json') with tf.io.gfile.GFile(config_file, 'w') as writer: writer.write(json.dumps(config, indent=4)) config['output_dir'] = FLAGS.output_dir if 'num_total_memories' not in config: config['num_total_memories'] = get_num_total_memories( ml_collections.ConfigDict(config)) generate(ml_collections.ConfigDict(config)) def validate_config_flags(flag_dict: Mapping[Text, Any]) -> bool: return flag_dict['config'] is not None or flag_dict[ 'config_file'] is not None if __name__ == '__main__': flags.register_multi_flags_validator( ['config', 'config_file'], validate_config_flags, 'Either --config or --config_file needs ' 'to be set.') app.run(main)
train_steps=FLAGS.train_steps): # continuous_train_and_eval() yields evaluation metrics after each # checkpoint. It also saves and logs them, so we don't do anything here. pass if __name__ == "__main__": tf.logging.set_verbosity(tf.logging.INFO) flags.mark_flags_as_required(["dataset", "model_dir", "schedule"]) def _validate_schedule(flag_values): """Validates the --schedule flag and the flags it interacts with.""" schedule = flag_values["schedule"] save_checkpoints_steps = flag_values["save_checkpoints_steps"] save_checkpoints_secs = flag_values["save_checkpoints_secs"] if schedule in ["train", "train_and_eval"]: if not (save_checkpoints_steps or save_checkpoints_secs): raise flags.ValidationError( "--schedule='%s' requires --save_checkpoints_steps or " "--save_checkpoints_secs." % schedule) return True flags.register_multi_flags_validator( ["schedule", "save_checkpoints_steps", "save_checkpoints_secs"], _validate_schedule) tf.app.run()
flags.DEFINE_float('sgd_momentum', 0.9, 'Momentum parameter for learning rate.') # See www.moderndescartes.com/essays/shuffle_viz for discussion on sizing flags.DEFINE_integer('shuffle_buffer_size', 20000, 'Size of buffer used to shuffle train examples.') flags.DEFINE_bool('use_tpu', False, 'Whether to use TPU for training.') flags.DEFINE_string( 'tpu_name', None, 'The Cloud TPU to use for training. This should be either the name used' 'when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.') flags.register_multi_flags_validator( ['use_tpu', 'tpu_name'], lambda flags: bool(flags['use_tpu']) == bool(flags['tpu_name']), 'If use_tpu is set, tpu_name must also be set.') flags.DEFINE_integer( 'iterations_per_loop', 100, help=('Number of steps to run on TPU before outfeeding metrics to the CPU.' ' If the number of iterations in the loop would exceed the number of' ' train steps, the loop will exit before reaching' ' --iterations_per_loop. The larger this value is, the higher the' ' utilization on the TPU.')) flags.DEFINE_integer( 'num_tpu_cores', default=8, help=(
--hpcc_math_library=mkl Args: myflags: Dict of flags from register_multi_flags_validator for --hpcc_use_intel_compiled_hpl and --hpcc_math_library """ if myflags['hpcc_use_intel_compiled_hpl']: return myflags['hpcc_math_library'] == HPCC_MATH_LIBRARY_MKL return True flags.register_validator( 'hpcc_benchmarks', lambda hpcc_benchmarks: set(hpcc_benchmarks).issubset( set(HPCC_BENCHMARKS))) flags.register_multi_flags_validator( ['hpcc_math_library', 'hpcc_use_intel_compiled_hpl'], CheckUseIntelCompiled, 'With --hpcc_use_intel_compiled_hpl must specify ' f'--hpcc_math_library={HPCC_MATH_LIBRARY_MKL}') FLAGS = flags.FLAGS def _LimitBenchmarksToRun(vm, selected_hpcc_benchmarks): """Limits the benchmarks to run. This function copies hpcc.c to the local machine, comments out code that runs benchmarks not listed in selected_hpcc_benchmarks, and then copies hpcc.c back to the remote machine. Args: vm: The machine where hpcc.c was installed. selected_hpcc_benchmarks: A set of benchmarks to run.
def must_add_gnss_sensor(): """ Returns true if the GNSS sensor must be added. We don't add all sensors by default because they slow down the simulation """ return FLAGS.localization # Flag validators. flags.register_multi_flags_validator( [ 'obstacle_detection', 'obstacle_detection_model_paths', 'obstacle_detection_model_names' ], lambda flags_dict: (not flags_dict['obstacle_detection'] or (flags_dict['obstacle_detection'] and (len(flags_dict['obstacle_detection_model_paths']) == len(flags_dict['obstacle_detection_model_names'])))), message='--obstacle_detection_model_paths and ' '--obstacle_detection_model_names must have the same length') def prediction_validator(flags_dict): if flags_dict['prediction']: return (flags_dict['obstacle_tracking'] or flags_dict['perfect_obstacle_tracking']) return True flags.register_multi_flags_validator(
tf.compat.v1.enable_resource_variables() FLAGS(argv) # raises UnrecognizedFlagError for undefined flags tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param, skip_unknown=False) train_eval(FLAGS.root_dir, FLAGS.experiment_name, train_eval_dir=FLAGS.train_eval_dir) if __name__ == '__main__': def validate_mutual_exclusion(flags_dict): valid_1 = (flags_dict['root_dir'] is not None and flags_dict['experiment_name'] is not None and flags_dict['train_eval_dir'] is None) valid_2 = (flags_dict['root_dir'] is None and flags_dict['experiment_name'] is None and flags_dict['train_eval_dir'] is not None) if valid_1 or valid_2: return True message = ('Exactly both root_dir and experiment_name or only ' 'train_eval_dir must be specified.') raise flags.ValidationError(message) flags.register_multi_flags_validator( ['root_dir', 'experiment_name', 'train_eval_dir'], validate_mutual_exclusion) app.run(main)
flags.DEFINE_bool('carla_localization', False,
                  'True to use perfect localization')


# Certain visualizations are not supported when running in challenge mode.
def unsupported_visualizations_validator(flags_dict):
    """Pass only when none of the unsupported visualization flags is set."""
    unsupported = (
        'visualize_depth_camera',
        'visualize_imu',
        'visualize_pose',
        'visualize_prediction',
    )
    return not any(flags_dict[flag_name] for flag_name in unsupported)


flags.register_multi_flags_validator(
    [
        'visualize_depth_camera',
        'visualize_imu',
        'visualize_pose',
        'visualize_prediction',
    ],
    unsupported_visualizations_validator,
    message='Trying to visualize unsupported_visualization')

# Camera mounting locations and the names the cameras are registered under.
CENTER_CAMERA_LOCATION = pylot.utils.Location(0.0, 0.0, 2.0)
CENTER_CAMERA_NAME = 'center_camera'
LANE_CAMERA_LOCATION = pylot.utils.Location(1.3, 0.0, 1.8)
LANE_CAMERA_NAME = 'lane_camera'
TL_CAMERA_NAME = 'traffic_lights_camera'
LEFT_CAMERA_NAME = 'left_camera'
RIGHT_CAMERA_NAME = 'right_camera'


def get_entry_point():
    """Return the name of the agent entry point, 'ERDOSAgent'."""
    return 'ERDOSAgent'
help='Use Squeeze and Excitation with bias.') flags.DEFINE_integer( 'SE_ratio', 2, help='Squeeze and Excitation ratio.') flags.DEFINE_bool( 'use_swish', False, help=('Use Swish activation function inplace of ReLu. ' 'https://arxiv.org/pdf/1710.05941.pdf')) # TODO(seth): Verify if this is still required. flags.register_multi_flags_validator( ['use_tpu', 'iterations_per_loop', 'summary_steps'], lambda flags: (not flags['use_tpu'] or flags['summary_steps'] % flags['iterations_per_loop'] == 0), 'If use_tpu, summary_steps must be a multiple of iterations_per_loop') FLAGS = flags.FLAGS class DualNetwork(): def __init__(self, save_file): self.save_file = save_file self.inference_input = None self.inference_output = None config = tf.ConfigProto() config.gpu_options.allow_growth = True if FLAGS.use_mgpu_horovod: config.gpu_options.visible_device_list = str(hvd.local_rank())
'Fraction of positions to filter from golden chunks,' 'default, 1.0 (no filter)') flags.DEFINE_string('export_path', None, 'Where to export the model after training.') flags.DEFINE_bool('use_bt', False, 'Whether to use Bigtable as input. ' '(Only supported with --use_tpu, currently.)') flags.DEFINE_bool('freeze', False, 'Whether to freeze the graph at the end of training.') flags.register_multi_flags_validator( ['use_bt', 'use_tpu'], lambda flags: flags['use_tpu'] if flags['use_bt'] else True, '`use_bt` flag only valid with `use_tpu` as well') @flags.multi_flags_validator( ['num_examples', 'steps_to_train', 'filter_amount'], '`num_examples` requires `steps_to_train==0` and `filter_amount==1.0`') def _example_flags_validator(flags_dict): if not flags_dict['num_examples']: return True return not flags_dict['steps_to_train'] and flags_dict['filter_amount'] == 1.0 @flags.multi_flags_validator( ['use_bt', 'cbt_project', 'cbt_instance', 'cbt_table'], message='Cloud Bigtable configuration flags not correct') def _bt_checker(flags_dict): if not flags_dict['use_bt']:
from absl import flags
from tensorflow import gfile
import shipname

# Note the trailing space in the first literal of each help text: adjacent
# string literals are concatenated without any separator.
flags.DEFINE_string(
    'base_dir', None,
    'Root directory if using local FS as the database. '
    'Leave blank if using bucket_name.')
flags.DEFINE_string(
    'bucket_name', None,
    'Bucket name if using GCS as the filesystem DB. '
    'Leave blank if using base_dir.')


def _exactly_one_location(flag_values):
    """Return True when exactly one of base_dir / bucket_name is truthy."""
    return bool(flag_values['base_dir']) != bool(flag_values['bucket_name'])


flags.register_multi_flags_validator(
    ['base_dir', 'bucket_name'],
    _exactly_one_location,
    'Exactly one of --base_dir, --bucket_name must be set!')

FLAGS = flags.FLAGS


def _with_base(*args):
    """Return a zero-argument callable that joins `args` onto the base path.

    The base is FLAGS.base_dir, or 'gs://<bucket_name>' when base_dir is
    falsy. It is read when the returned callable is invoked, not when it is
    created, so the callable can be built before flags are parsed.
    """
    def inner():
        base_dir = FLAGS.base_dir or 'gs://{}'.format(FLAGS.bucket_name)
        return os.path.join(base_dir, *args)
    return inner


# Functions to compute various important directories, based on FLAGS input.
models_dir = _with_base('models')
flags.DEFINE_integer('trunk_layers', go.N, 'The number of resnet layers in the shared trunk.') flags.DEFINE_multi_integer( 'lr_boundaries', [400000, 600000], 'The number of steps at which the learning rate will decay') flags.DEFINE_multi_float('lr_rates', [0.01, 0.001, 0.0001], 'The different learning rates') flags.DEFINE_integer('training_seed', 0, 'Random seed to use for training and validation') flags.register_multi_flags_validator( ['lr_boundaries', 'lr_rates'], lambda flags: len(flags['lr_boundaries']) == len(flags['lr_rates']) - 1, 'Number of learning rates must be exactly one greater than the number of boundaries' ) flags.DEFINE_float('l2_strength', 1e-4, 'The L2 regularization parameter applied to weights.') flags.DEFINE_float( 'value_cost_weight', 1.0, 'Scalar for value_cost, AGZ paper suggests 1/100 for ' 'supervised learning') flags.DEFINE_float('sgd_momentum', 0.9, 'Momentum parameter for learning rate.') flags.DEFINE_string(
'use_bt', False, 'Whether to use Bigtable as input. ' '(Only supported with --use_tpu, currently.)') flags.DEFINE_bool('freeze', False, 'Whether to freeze the graph at the end of training.') flags.DEFINE_boolean('profile_hvd', False, 'Whether to profile horovod based multi-gpu training.') flags.DEFINE_boolean('use_trt', False, 'True to write a GraphDef that uses the TRT runtime') flags.DEFINE_integer('trt_max_batch_size', None, 'Maximum TRT batch size') flags.DEFINE_string('trt_precision', 'fp32', 'Precision for TRT runtime: fp16, fp32 or int8') flags.register_multi_flags_validator( ['use_trt', 'trt_max_batch_size'], lambda flags: not flags['use_trt'] or flags['trt_max_batch_size'], 'trt_max_batch_size must be set if use_trt is true') flags.register_multi_flags_validator( ['use_bt', 'use_tpu'], lambda flags: flags['use_tpu'] if flags['use_bt'] else True, '`use_bt` flag only valid with `use_tpu` as well') @flags.multi_flags_validator( ['num_examples', 'steps_to_train', 'filter_amount'], '`num_examples` requires `steps_to_train==0` and `filter_amount==1.0`') def _example_flags_validator(flags_dict): if not flags_dict['num_examples']: return True return not flags_dict['steps_to_train'] and flags_dict[
######################################## # Recording operators. ######################################## flags.DEFINE_string('data_path', 'data/', 'Path where to logged data') flags.DEFINE_bool('log_detector_output', False, 'Enable recording of bbox annotated detector images') flags.DEFINE_bool('log_traffic_light_detector_output', False, 'Enable recording of bbox annotated tl detector images') # Flag validators. flags.register_multi_flags_validator( [ 'obstacle_detection', 'obstacle_detection_model_paths', 'obstacle_detection_model_names' ], lambda flags_dict: (not flags_dict['obstacle_detection'] or (flags_dict['obstacle_detection'] and (len(flags_dict['obstacle_detection_model_paths']) == len(flags_dict['obstacle_detection_model_names'])))), message='--obstacle_detection_model_paths and ' '--obstacle_detection_model_names must have the same length') def prediction_validator(flags_dict): if flags_dict['prediction']: return (flags_dict['obstacle_tracking'] or flags_dict['perfect_obstacle_tracking']) return True flags.register_multi_flags_validator(
flags.DEFINE_integer('trunk_layers', go.N, 'The number of resnet layers in the shared trunk.') flags.DEFINE_multi_integer( 'lr_boundaries', [400000, 600000], 'The number of steps at which the learning rate will decay') flags.DEFINE_multi_float('lr_rates', [0.01, 0.001, 0.0001], 'The different learning rates') flags.DEFINE_integer('training_seed', 0, 'Random seed to use for training and validation') flags.register_multi_flags_validator( ['lr_boundaries', 'lr_rates'], lambda flags: len(flags['lr_boundaries']) == len(flags['lr_rates']) - 1, 'Number of learning rates must be exactly one greater than the number of boundaries' ) flags.DEFINE_float('l2_strength', 1e-4, 'The L2 regularization parameter applied to weights.') flags.DEFINE_float( 'value_cost_weight', 1.0, 'Scalar for value_cost, AGZ paper suggests 1/100 for ' 'supervised learning') flags.DEFINE_float('sgd_momentum', 0.9, 'Momentum parameter for learning rate.') flags.DEFINE_string(