flags.DEFINE_integer('bigtable_node_count', 3,
                     'Number of nodes to create in the bigtable cluster.')
flags.DEFINE_enum('bigtable_storage_type', 'ssd', ['ssd', 'hdd'],
                  'Storage class for the cluster')
flags.DEFINE_string('google_bigtable_zone', 'us-central1-b', 'Bigtable zone.')
flags.DEFINE_boolean('bigtable_replication_cluster', False,
                     'Whether to create a Bigtable replication cluster.')
flags.DEFINE_string('bigtable_replication_cluster_zone', None,
                    'Zone in which to create a Bigtable replication cluster.')
flags.DEFINE_boolean('bigtable_multicluster_routing', False,
                     'Whether to use multi-cluster routing.')
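# _ValidateReplicationFlags and _ValidateRoutingFlags are defined elsewhere in
# the module; below are minimal sketches consistent with the validator
# messages registered next (illustrative, not the project's exact code):
def _ValidateReplicationFlags(flags_dict):
  # A replication cluster requires a zone to create it in.
  return (not flags_dict['bigtable_replication_cluster'] or
          bool(flags_dict['bigtable_replication_cluster_zone']))


def _ValidateRoutingFlags(flags_dict):
  # Multi-cluster routing only makes sense with a replication cluster.
  return (not flags_dict['bigtable_multicluster_routing'] or
          flags_dict['bigtable_replication_cluster'])
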
flags.register_multi_flags_validator(
    ['bigtable_replication_cluster', 'bigtable_replication_cluster_zone'],
    _ValidateReplicationFlags,
    message='bigtable_replication_cluster_zone must '
    'be set if bigtable_replication_cluster is True.')
flags.register_multi_flags_validator(
    ['bigtable_replication_cluster', 'bigtable_multicluster_routing'],
    _ValidateRoutingFlags,
    message='bigtable_replication_cluster must '
    'be set if bigtable_multicluster_routing is True.')


class GcpBigtableInstance(resource.BaseResource):
    """Object representing a GCP Bigtable Instance.

  Attributes:
    name: Instance and cluster name.
    num_nodes: Number of nodes in the instance's cluster.
Example #2
flags.DEFINE_string('output_dir',
                    None,
                    help='Directory in which to save the output.')
flags.DEFINE_string('runner',
                    'direct',
                    help='Beam runner (`direct` or `dataflow`).')
flags.mark_flag_as_required('output_dir')


def validate_project_id_flag(inputs):
    needs_project_id = bool('dataflow' in inputs['runner'].lower()
                            or inputs['input_tce_table'])
    return bool(inputs['project_id']) if needs_project_id else True


flags.register_multi_flags_validator(
    ['project_id', 'input_tce_table', 'runner'],
    validate_project_id_flag,
    message='--project_id must be set if running on Dataflow or '
    'reading from BigQuery.')


def read_tces(input_tce_csv_file, input_tce_table=None, project_id=None):
    """Read, filter, and partition a table of Kepler KOIs.

  Args:
    input_tce_csv_file: CSV file containing the Q1-Q17 DR24 Kepler TCE table.
    input_tce_table: BQ table name containing the Q1-Q17 DR24 Kepler TCE table.
    project_id: GCP project ID. Required if `input_tce_table` is passed.
  Returns:
    A `dict` with keys ['train', 'val', 'test'], where each value is a list of
      `dict` TCE records belonging to that subset.
  """
Example #3
                     'cleaned up before deleting the network. If firewall '
                     'rules are added manually, PKB will not know about all of '
                     'them. However, they must be deleted in order to '
                     'successfully delete the PKB-created network.')
flags.DEFINE_enum('bq_client_interface', 'CLI',
                  ['CLI', 'JAVA', 'SIMBA_JDBC_1_2_4_1007'],
                  'The Runtime Interface used when interacting with BigQuery.')
flags.DEFINE_string('gcp_preemptible_status_bucket', None,
                    'The GCS bucket to store the preemptible status when '
                    'running on GCP.')
flags.DEFINE_integer(
    'gcp_provisioned_iops', 100000,
    'Iops to provision for pd-extreme. Defaults to the gcloud '
    'default of 100000.')
API_OVERRIDE = flags.DEFINE_string(
    'gcp_cloud_redis_api_override',
    default='https://redis.googleapis.com/',
    help='Cloud redis API endpoint override. Defaults to prod.')


def _ValidatePreemptFlags(flags_dict):
  if flags_dict['gce_preemptible_vms']:
    return bool(flags_dict['gcp_preemptible_status_bucket'])
  return True


flags.register_multi_flags_validator(
    ['gce_preemptible_vms', 'gcp_preemptible_status_bucket'],
    _ValidatePreemptFlags, 'When gce_preemptible_vms is specified, '
    'gcp_preemptible_status_bucket must be specified.')
Example #4
  if FLAGS.config_file:
    with tf.io.gfile.GFile(FLAGS.config_file, "r") as reader:
      config = json.load(reader)
  else:
    config = json.loads(FLAGS.config)
    # Save the config to the model directory if it does not exist yet.
    if jax.process_index() == 0:
      config_file = os.path.join(FLAGS.model_dir, "config.json")
      with tf.io.gfile.GFile(config_file, "w") as writer:
        writer.write(json.dumps(config, indent=4))

  config["model_dir"] = FLAGS.model_dir
  if FLAGS.learning_rate is not None:
    config["learning_rate"] = FLAGS.learning_rate
  if FLAGS.per_device_batch_size is not None:
    config["per_device_batch_size"] = FLAGS.per_device_batch_size
  if FLAGS.num_train_steps is not None:
    config["num_train_steps"] = FLAGS.num_train_steps
  if FLAGS.warmup_steps is not None:
    config["warmup_steps"] = FLAGS.warmup_steps

  train(ml_collections.ConfigDict(config))


if __name__ == "__main__":
  flags.register_multi_flags_validator(["config", "config_file"],
                                       validate_config_flags,
                                       "Either --config or --config_file needs "
                                       "to be set.")
  app.run(main)
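
The validate_config_flags checker used above is outside this excerpt; a later
example in this collection defines it, reproduced here for context:

def validate_config_flags(flag_dict: Mapping[Text, Any]) -> bool:
  return flag_dict['config'] is not None or flag_dict['config_file'] is not None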
Example #5
flags.DEFINE_string('model_dir', None, 'The working directory of the model')

# See www.moderndescartes.com/essays/shuffle_viz for discussion on sizing
flags.DEFINE_integer('shuffle_buffer_size', 20000,
                     'Size of buffer used to shuffle train examples.')

flags.DEFINE_bool('use_tpu', False, 'Whether to use TPU for training.')

flags.DEFINE_string(
    'tpu_name', None,
    'The Cloud TPU to use for training. This should be either the name used '
    'when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.')

flags.register_multi_flags_validator(
    ['use_tpu', 'tpu_name'],
    lambda flags: bool(flags['use_tpu']) == bool(flags['tpu_name']),
    'If use_tpu is set, tpu_name must also be set.')

flags.register_multi_flags_validator(
    ['lr_boundaries', 'lr_rates'],
    lambda flags: len(flags['lr_boundaries']) == len(flags['lr_rates']) - 1,
    'Number of learning rates must be exactly one greater than the number of boundaries'
)

flags.DEFINE_integer(
    'iterations_per_loop',
    200,
    help=('Number of steps to run on TPU before outfeeding metrics to the CPU.'
          ' If the number of iterations in the loop would exceed the number of'
          ' train steps, the loop will exit before reaching'
          ' --iterations_per_loop. The larger this value is, the higher the'
Example #6
flags.DEFINE_string('pretrained_model_folder', None,
                    'Model folder under ./trained_models if --eval is set')

flags.DEFINE_float('recall_inference_bias', None, 'Recall bias value')

flags.DEFINE_enum('token_embedding_dimension', None, ['100', '300'],
                  'Token embedding dimension size.')

flags.DEFINE_integer('threads_tf', 32, 'Num threads for any tf op.')

flags.DEFINE_integer(
    'threads_prediction', 100,
    'Num threads for eval prediction data partitioning '
    'into chunks for that number of threads.')

flags.mark_flags_as_required(
    ['dataset_text_folder', 'output_folder', 'token_embedding_dimension'])

flags.mark_bool_flags_as_mutual_exclusive(['train', 'eval'], required=True)

flags.register_multi_flags_validator(
    flag_names=['train', 'recall_inference_bias'],
    multi_flags_checker=lambda flags: not flags['train'] or flags[
        'recall_inference_bias'] in [None, 0.],
    message='In train mode, recall_inference_bias must be unset or zero.')

flags.register_multi_flags_validator(
    flag_names=['eval', 'pretrained_model_folder'],
    multi_flags_checker=lambda flags: not flags['eval'] or all(flags.values()),
    message='In eval mode, all these flags must be set.')
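
Note that all(flags.values()) in the second validator iterates over every flag
named in flag_names, so in eval mode it effectively requires
pretrained_model_folder to be set; the 'eval' entry itself is necessarily
truthy whenever that branch of the lambda is evaluated.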
Example #7
def must_add_imu_sensor():
    """ Returns true if the IMU sensor must be added.

    We don't add all sensors by default because they slow down the simulation.
    """
    return (FLAGS.imu or FLAGS.evaluation)


# Flag validators.
flags.register_multi_flags_validator(
    [
        'obstacle_detection', 'obstacle_detection_model_paths',
        'obstacle_detection_model_names'
    ],
    lambda flags_dict: (not flags_dict['obstacle_detection'] or
                        (flags_dict['obstacle_detection'] and
                         (len(flags_dict['obstacle_detection_model_paths']) ==
                          len(flags_dict['obstacle_detection_model_names'])))),
    message='--obstacle_detection_model_paths and '
    '--obstacle_detection_model_names must have the same length')


def prediction_validator(flags_dict):
    if flags_dict['prediction']:
        return (flags_dict['obstacle_tracking']
                or flags_dict['perfect_obstacle_tracking'])
    return True


flags.register_multi_flags_validator(
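# (The excerpt ends mid-call; judging from prediction_validator's body, the
# registration plausibly completes along these lines. The message text is a
# guess, not the project's actual string.)
#     ['prediction', 'obstacle_tracking', 'perfect_obstacle_tracking'],
#     prediction_validator,
#     message='--prediction requires --obstacle_tracking or '
#     '--perfect_obstacle_tracking')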
Example #8
flags.DEFINE_bool('record_lidar', False, 'True to record lidar')
flags.DEFINE_bool('record_rgb_camera', False, 'True to record RGB camera')
flags.DEFINE_bool(
    'record_ground_truth', False,
    'True to record ground-truth data (e.g., vehicle position, traffic lights)')

# Other flags
flags.DEFINE_integer('num_cameras', 5, 'Number of cameras.')

# Flag validators.
flags.register_validator('framework',
                         lambda value: value == 'ros' or value == 'ray',
                         message='--framework must be: ros | ray')
flags.register_multi_flags_validator(
    ['replay', 'evaluate_obj_detection'],
    lambda flags_dict: not (flags_dict['replay'] and flags_dict[
        'evaluate_obj_detection']),
    message='--evaluate_obj_detection cannot be set when --replay is set')
flags.register_multi_flags_validator(
    ['replay', 'fusion'],
    lambda flags_dict: not (flags_dict['replay'] and flags_dict['fusion']),
    message='--fusion cannot be set when --replay is set')
# flags.register_multi_flags_validator(
#     ['ground_agent_operator', 'obj_detection', 'traffic_light_det', 'segmentation_drn', 'segmentation_dla'],
#     lambda flags_dict: (flags_dict['ground_agent_operator'] or
#                         (flags_dict['obj_detection'] and
#                          flags_dict['traffic_light_det'] and
#                          (flags_dict['segmentation_drn'] or flags_dict['segmentation_dla']))),
#     message='ERDOS agent requires obj detection, segmentation and traffic light detection')
flags.register_multi_flags_validator(
    [
Example #9
        train_steps=FLAGS.train_steps):
      # continuous_train_and_eval() yields evaluation metrics after each
      # checkpoint. It also saves and logs them, so we don't do anything here.
      pass


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)

  flags.mark_flags_as_required(["dataset", "model_dir", "schedule"])

  def _validate_schedule(flag_values):
    """Validates the --schedule flag and the flags it interacts with."""
    schedule = flag_values["schedule"]
    save_checkpoints_steps = flag_values["save_checkpoints_steps"]
    save_checkpoints_secs = flag_values["save_checkpoints_secs"]

    if schedule in ["train", "train_and_eval"]:
      if not (save_checkpoints_steps or save_checkpoints_secs):
        raise flags.ValidationError(
            "--schedule='%s' requires --save_checkpoints_steps or "
            "--save_checkpoints_secs." % schedule)

    return True

  flags.register_multi_flags_validator(
      ["schedule", "save_checkpoints_steps", "save_checkpoints_secs"],
      _validate_schedule)

  tf.app.run()
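
Note that _validate_schedule raises flags.ValidationError itself rather than
returning False; absl accepts either style from a multi-flags checker, and a
raised error's own message is the one reported, which is why no message
argument is passed to register_multi_flags_validator here.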
Example #10
flags.DEFINE_string('dataset_text_folder', 'i2b2-2014-paper',
                    'dataset to eval on')

flags.DEFINE_enum('edim', None, ['100', '300'], 'Token embedding dimension')

FLAGS = flags.FLAGS

flags.mark_flags_as_required(['edim'])

# Must specify only one verb.
flags.mark_bool_flags_as_mutual_exclusive(['eval', 'eval-view-test-bias'],
                                          required=True)

flags.register_multi_flags_validator(
    flag_names=['eval', 'rbias', 'pretrained_model_folder'],
    multi_flags_checker=lambda flags: not flags['eval'] or all(flags.values()),
    message='In eval mode, all these flags must be set.')


def start_and_wait_for_jobs(processes, descriptions):
    assert len(processes) == len(descriptions)
    jobs = []
    processes_descriptions = list(zip(processes, descriptions))
    for i, (process, description) in enumerate(processes_descriptions):
        logging.info('Starting %s', description)
        jobs.append(process.run_bg())
        if i != len(processes_descriptions) - 1:
            time.sleep(1)
    for running_jobs in more_itertools.repeatfunc(sum, None, (not j.ready()
                                                              for j in jobs)):
        if running_jobs == 0:
Example #11
# Optional
flags.DEFINE_boolean('force', False,
                     'Force update, regardless of last update time')

# Must be present if CM
flags.DEFINE_boolean('cm_profiles', False, 'List available CM profiles.')
flags.DEFINE_integer('profile', None,
                     'Campaign Manager profile id. Only needed for CM.')
flags.DEFINE_boolean('cm_superuser', False,
                     'User is an _internal_ CM Superuser.')
flags.DEFINE_integer('account', None,
                     'CM account id. Required for CM Superusers.')

flags.register_multi_flags_validator(
    ['account', 'cm_superuser'],
    lambda value: (value.get('account') and value['cm_superuser']) or
    (not value['cm_superuser'] and not value['account']),
    message=
    '--account must be set for a superuser, and not set for normal users.')

#flags.register_multi_flags_validator(['cm_id', 'profile'],
#                        lambda value: (value['cm_id'] and not value['profile']) or (not value['cm_id'] and value['profile']),
#                        message='profile must be set for a CM report to be specified')


# Stub main()
def main(unused_argv):
    if FLAGS.dv360_id:
        runner = DBMReportRunner(dbm_id=FLAGS.dv360_id,
                                 email=FLAGS.email,
                                 project=FLAGS.project)
Example #12
                FLAGS.zinbwave_dims,
                FLAGS.zinbwave_epsilon,
                FLAGS.zinbwave_keep_variance,
                FLAGS.zinbwave_gene_covariate,
                metrics.silhouette,
                metrics.ami,
                metrics.ari,
                metrics.kmeans_silhouette,
                adata.n_obs,
                FLAGS.tissue,
                n_clusters,
            ])

        if FLAGS.output_h5ad:
            adata.write(FLAGS.output_h5ad)


if __name__ == '__main__':
    flags.mark_flags_as_mutual_exclusive(['input_loom', 'input_csvs'])
    flags.mark_flag_as_required('output_csv')
    flags.mark_flag_as_required('reduced_dim')
    flags.mark_flag_as_required('tissue')
    flags.mark_flag_as_required('source')
    flags.register_multi_flags_validator(
        flag_names=(
            ['source'] +
            list(itertools.chain.from_iterable(_SOURCE_TO_FLAGS.values()))),
        multi_flags_checker=check_flags_combination,
        message='Source and other flags are not compatible.')
    app.run(main)
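
# check_flags_combination and _SOURCE_TO_FLAGS live outside this excerpt. A
# minimal sketch, assuming _SOURCE_TO_FLAGS maps each source name to the flags
# that source requires (the mapping below is illustrative, not the project's):
_SOURCE_TO_FLAGS = {'loom': ['input_loom'], 'csvs': ['input_csvs']}


def check_flags_combination(flags_dict):
    # Flags required by the chosen source must be set; flags belonging to
    # other sources must be left unset.
    required = set(_SOURCE_TO_FLAGS.get(flags_dict['source'], []))
    others = set(
        itertools.chain.from_iterable(_SOURCE_TO_FLAGS.values())) - required
    return (all(flags_dict[f] for f in required)
            and not any(flags_dict[f] for f in others))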
Example #13
    'https://cloud.google.com/spanner/docs/cpu-utilization#recommended-max.')
_CPU_TARGET_HIGH_PRIORITY_UPPER_BOUND = flags.DEFINE_float(
    'cloud_spanner_ycsb_cpu_optimization_target_max', 0.75,
    'Maximum target CPU utilization after which the benchmark will throw an '
    'exception. This is needed so that in CPU-optimized mode, the increase in '
    'QPS does not overshoot the target CPU percentage by too much.')


def _ValidateCpuTargetFlags(flags_dict):
    return (flags_dict['cloud_spanner_ycsb_cpu_optimization_target_max'] >
            flags_dict['cloud_spanner_ycsb_cpu_optimization_target'])


flags.register_multi_flags_validator(
    [
        'cloud_spanner_ycsb_cpu_optimization_target',
        'cloud_spanner_ycsb_cpu_optimization_target_max'
    ], _ValidateCpuTargetFlags,
    'CPU optimization max target must be greater than target.')

_CPU_OPTIMIZATION_INCREMENT_MINUTES = flags.DEFINE_integer(
    'cloud_spanner_ycsb_cpu_optimization_workload_mins', 30,
    'Length of time to run YCSB until incrementing QPS.')
_CPU_OPTIMIZATION_MEASUREMENT_MINUTES = flags.DEFINE_integer(
    'cloud_spanner_ycsb_cpu_optimization_measurement_mins', 5,
    'Length of time to measure average CPU at the end of a test. For example, '
    'the default 5 means that only the last 5 minutes of the test will be '
    'used for representative CPU utilization.')
_STARTING_QPS = flags.DEFINE_integer(
    'cloud_spanner_ycsb_min_target', None,
    'Starting QPS to set as YCSB target. Defaults to a value which uses the '
    'published throughput expectations for each node, see READ/WRITE caps per '
Example #14
from absl import app, flags

import dual_net_prim

flags.DEFINE_string('model_path', None, 'Path to model to freeze')

flags.mark_flag_as_required('model_path')

flags.DEFINE_boolean(
    'use_trt', False, 'True to write a GraphDef that uses the TRT runtime')
flags.DEFINE_integer('trt_max_batch_size', None,
                     'Maximum TRT batch size')
flags.DEFINE_string('trt_precision', 'fp32',
                    'Precision for TRT runtime: fp16, fp32 or int8')
flags.register_multi_flags_validator(
    ['use_trt', 'trt_max_batch_size'],
    lambda flags: not flags['use_trt'] or flags['trt_max_batch_size'],
    'trt_max_batch_size must be set if use_trt is true')

FLAGS = flags.FLAGS


def main(unused_argv):
    """Freeze a model to a GraphDef proto."""
    dual_net_prim.freeze_graph(FLAGS.model_path, FLAGS.use_trt,
                               FLAGS.trt_max_batch_size, FLAGS.trt_precision)


if __name__ == "__main__":
    app.run(main)
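
# A typical invocation might look like this (the script and model names are
# hypothetical):
#   python freeze_graph.py --model_path=models/000123-bonsai \
#       --use_trt --trt_max_batch_size=8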
Example #15
File: flags.py Project: ymote/pylot
flags.DEFINE_float('gnss_bias_lat', 0.0,
                   'Sets the bias on the latitude of the GNSS sensor.')
flags.DEFINE_float('gnss_bias_lon', 0.0,
                   'Sets the bias on the longitude of the GNSS sensor.')


def sensor_frequency_validator(flags_dict):
    # No sensor may run at a higher frequency than the simulator itself.
    fps = flags_dict['simulator_fps']
    return all(flags_dict[name] <= fps for name in (
        'simulator_camera_frequency', 'simulator_lidar_frequency',
        'simulator_imu_frequency', 'simulator_localization_frequency',
        'simulator_control_frequency'))


flags.register_multi_flags_validator(
    [
        'simulator_fps',
        'simulator_camera_frequency',
        'simulator_imu_frequency',
        'simulator_lidar_frequency',
        'simulator_localization_frequency',
        'simulator_control_frequency',
    ],
    sensor_frequency_validator,
    message='Sensor frequencies cannot be greater than --simulator_fps')
Example #16
    if FLAGS.config_file:
        with tf.io.gfile.GFile(FLAGS.config_file, 'r') as reader:
            config = json.load(reader)
    else:
        config = json.loads(FLAGS.config)
        # Save the config to the output directory if it does not exist yet.
        if jax.process_index() == 0:
            config_file = os.path.join(FLAGS.output_dir, 'config.json')
            with tf.io.gfile.GFile(config_file, 'w') as writer:
                writer.write(json.dumps(config, indent=4))

    config['output_dir'] = FLAGS.output_dir

    if 'num_total_memories' not in config:
        config['num_total_memories'] = get_num_total_memories(
            ml_collections.ConfigDict(config))

    generate(ml_collections.ConfigDict(config))


def validate_config_flags(flag_dict: Mapping[Text, Any]) -> bool:
    return flag_dict['config'] is not None or flag_dict[
        'config_file'] is not None


if __name__ == '__main__':
    flags.register_multi_flags_validator(
        ['config', 'config_file'], validate_config_flags,
        'Either --config or --config_file needs '
        'to be set.')
    app.run(main)
Example #18
flags.DEFINE_float('sgd_momentum', 0.9,
                   'Momentum parameter for learning rate.')

# See www.moderndescartes.com/essays/shuffle_viz for discussion on sizing
flags.DEFINE_integer('shuffle_buffer_size', 20000,
                     'Size of buffer used to shuffle train examples.')

flags.DEFINE_bool('use_tpu', False, 'Whether to use TPU for training.')

flags.DEFINE_string(
    'tpu_name', None,
    'The Cloud TPU to use for training. This should be either the name used '
    'when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.')

flags.register_multi_flags_validator(
    ['use_tpu', 'tpu_name'],
    lambda flags: bool(flags['use_tpu']) == bool(flags['tpu_name']),
    'If use_tpu is set, tpu_name must also be set.')

flags.DEFINE_integer(
    'iterations_per_loop',
    100,
    help=('Number of steps to run on TPU before outfeeding metrics to the CPU.'
          ' If the number of iterations in the loop would exceed the number of'
          ' train steps, the loop will exit before reaching'
          ' --iterations_per_loop. The larger this value is, the higher the'
          ' utilization on the TPU.'))

flags.DEFINE_integer(
    'num_tpu_cores',
    default=8,
    help=(
Example #19
  --hpcc_math_library=mkl

  Args:
    myflags: Dict of flags from register_multi_flags_validator for
      --hpcc_use_intel_compiled_hpl and --hpcc_math_library
  """
    if myflags['hpcc_use_intel_compiled_hpl']:
        return myflags['hpcc_math_library'] == HPCC_MATH_LIBRARY_MKL
    return True


flags.register_validator(
    'hpcc_benchmarks', lambda hpcc_benchmarks: set(hpcc_benchmarks).issubset(
        set(HPCC_BENCHMARKS)))
flags.register_multi_flags_validator(
    ['hpcc_math_library', 'hpcc_use_intel_compiled_hpl'],
    CheckUseIntelCompiled, 'With --hpcc_use_intel_compiled_hpl must specify '
    f'--hpcc_math_library={HPCC_MATH_LIBRARY_MKL}')
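# HPCC_BENCHMARKS and HPCC_MATH_LIBRARY_MKL are module-level constants defined
# outside this excerpt; given the docstring above, the latter is presumably
# just the flag value it matches:
HPCC_MATH_LIBRARY_MKL = 'mkl'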

FLAGS = flags.FLAGS


def _LimitBenchmarksToRun(vm, selected_hpcc_benchmarks):
    """Limits the benchmarks to run.

  This function copies hpcc.c to the local machine, comments out code that runs
  benchmarks not listed in selected_hpcc_benchmarks, and then copies hpcc.c back
  to the remote machine.

  Args:
    vm: The machine where hpcc.c was installed.
    selected_hpcc_benchmarks: A set of benchmarks to run.
Example #20
def must_add_gnss_sensor():
    """ Returns true if the GNSS sensor must be added.

    We don't add all sensors by default because they slow down the simulation.
    """
    return FLAGS.localization


# Flag validators.
flags.register_multi_flags_validator(
    [
        'obstacle_detection', 'obstacle_detection_model_paths',
        'obstacle_detection_model_names'
    ],
    lambda flags_dict: (not flags_dict['obstacle_detection'] or
                        (flags_dict['obstacle_detection'] and
                         (len(flags_dict['obstacle_detection_model_paths']) ==
                          len(flags_dict['obstacle_detection_model_names'])))),
    message='--obstacle_detection_model_paths and '
    '--obstacle_detection_model_names must have the same length')


def prediction_validator(flags_dict):
    if flags_dict['prediction']:
        return (flags_dict['obstacle_tracking']
                or flags_dict['perfect_obstacle_tracking'])
    return True


flags.register_multi_flags_validator(
Example #21
    tf.compat.v1.enable_resource_variables()
    FLAGS(argv)  # raises UnrecognizedFlagError for undefined flags
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    gin.parse_config_files_and_bindings(FLAGS.gin_file,
                                        FLAGS.gin_param,
                                        skip_unknown=False)
    train_eval(FLAGS.root_dir,
               FLAGS.experiment_name,
               train_eval_dir=FLAGS.train_eval_dir)


if __name__ == '__main__':

    def validate_mutual_exclusion(flags_dict):
        valid_1 = (flags_dict['root_dir'] is not None
                   and flags_dict['experiment_name'] is not None
                   and flags_dict['train_eval_dir'] is None)
        valid_2 = (flags_dict['root_dir'] is None
                   and flags_dict['experiment_name'] is None
                   and flags_dict['train_eval_dir'] is not None)
        if valid_1 or valid_2:
            return True
        message = ('Either both root_dir and experiment_name must be '
                   'specified, or only train_eval_dir.')
        raise flags.ValidationError(message)

    flags.register_multi_flags_validator(
        ['root_dir', 'experiment_name', 'train_eval_dir'],
        validate_mutual_exclusion)
    app.run(main)
Example #22
flags.DEFINE_bool('carla_localization', False,
                  'True to use perfect localization')


# Certain visualizations are not supported when running in challenge mode.
def unsupported_visualizations_validator(flags_dict):
    return not (flags_dict['visualize_depth_camera']
                or flags_dict['visualize_imu'] or flags_dict['visualize_pose']
                or flags_dict['visualize_prediction'])


flags.register_multi_flags_validator(
    [
        'visualize_depth_camera', 'visualize_imu', 'visualize_pose',
        'visualize_prediction'
    ],
    unsupported_visualizations_validator,
    message='These visualizations are not supported when running in '
    'challenge mode.')

CENTER_CAMERA_LOCATION = pylot.utils.Location(0.0, 0.0, 2.0)
CENTER_CAMERA_NAME = 'center_camera'
LANE_CAMERA_LOCATION = pylot.utils.Location(1.3, 0.0, 1.8)
LANE_CAMERA_NAME = 'lane_camera'
TL_CAMERA_NAME = 'traffic_lights_camera'
LEFT_CAMERA_NAME = 'left_camera'
RIGHT_CAMERA_NAME = 'right_camera'


def get_entry_point():
    return 'ERDOSAgent'
Example #23
    help='Use Squeeze and Excitation with bias.')

flags.DEFINE_integer(
    'SE_ratio', 2,
    help='Squeeze and Excitation ratio.')

flags.DEFINE_bool(
    'use_swish', False,
    help=('Use Swish activation function in place of ReLU. '
          'https://arxiv.org/pdf/1710.05941.pdf'))


# TODO(seth): Verify if this is still required.
flags.register_multi_flags_validator(
    ['use_tpu', 'iterations_per_loop', 'summary_steps'],
    lambda flags: (not flags['use_tpu'] or
                   flags['summary_steps'] % flags['iterations_per_loop'] == 0),
    'If use_tpu, summary_steps must be a multiple of iterations_per_loop')

FLAGS = flags.FLAGS


class DualNetwork():
    def __init__(self, save_file):
        self.save_file = save_file
        self.inference_input = None
        self.inference_output = None
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        if FLAGS.use_mgpu_horovod:
            config.gpu_options.visible_device_list = str(hvd.local_rank())
Example #24
                   'Fraction of positions to filter from golden chunks; '
                   'default 1.0 (no filter).')

flags.DEFINE_string('export_path', None,
                    'Where to export the model after training.')

flags.DEFINE_bool('use_bt', False,
                  'Whether to use Bigtable as input.  '
                  '(Only supported with --use_tpu, currently.)')

flags.DEFINE_bool('freeze', False,
                  'Whether to freeze the graph at the end of training.')


flags.register_multi_flags_validator(
    ['use_bt', 'use_tpu'],
    lambda flags: flags['use_tpu'] if flags['use_bt'] else True,
    '`use_bt` flag only valid with `use_tpu` as well')

@flags.multi_flags_validator(
    ['num_examples', 'steps_to_train', 'filter_amount'],
    '`num_examples` requires `steps_to_train==0` and `filter_amount==1.0`')
def _example_flags_validator(flags_dict):
    if not flags_dict['num_examples']:
        return True
    return not flags_dict['steps_to_train'] and flags_dict['filter_amount'] == 1.0
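
# Note: @flags.multi_flags_validator (used below) is the decorator form of
# flags.register_multi_flags_validator; both register the function as a
# checker over the listed flags, and the decorator returns the function
# unchanged.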

@flags.multi_flags_validator(
    ['use_bt', 'cbt_project', 'cbt_instance', 'cbt_table'],
    message='Cloud Bigtable configuration flags not correct')
def _bt_checker(flags_dict):
    if not flags_dict['use_bt']:
        return True
    # All Cloud Bigtable config flags must be set when --use_bt is enabled.
    return all(
        flags_dict[k] for k in ('cbt_project', 'cbt_instance', 'cbt_table'))
Example #25
from absl import flags
from tensorflow import gfile

import shipname

flags.DEFINE_string(
    'base_dir', None, 'Root directory if using local FS as the database. '
    'Leave blank if using bucket_name.')

flags.DEFINE_string(
    'bucket_name', None, 'Bucket name if using GCS as the filesystem DB. '
    'Leave blank if using base_dir.')

flags.register_multi_flags_validator(
    ['base_dir', 'bucket_name'],
    lambda flags: bool(flags['base_dir']) != bool(flags['bucket_name']),
    'Exactly one of --base_dir, --bucket_name must be set!')
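# Note: bool(a) != bool(b) is an exclusive-or, so this validator passes only
# when exactly one of --base_dir and --bucket_name is supplied.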

FLAGS = flags.FLAGS


def _with_base(*args):
    def inner():
        base_dir = FLAGS.base_dir or 'gs://{}'.format(FLAGS.bucket_name)
        return os.path.join(base_dir, *args)

    return inner


# Functions to compute various important directories, based on FLAGS input.
models_dir = _with_base('models')
Example #26
flags.DEFINE_integer('trunk_layers', go.N,
                     'The number of resnet layers in the shared trunk.')

flags.DEFINE_multi_integer(
    'lr_boundaries', [400000, 600000],
    'The number of steps at which the learning rate will decay')

flags.DEFINE_multi_float('lr_rates', [0.01, 0.001, 0.0001],
                         'The different learning rates')

flags.DEFINE_integer('training_seed', 0,
                     'Random seed to use for training and validation')

flags.register_multi_flags_validator(
    ['lr_boundaries', 'lr_rates'],
    lambda flags: len(flags['lr_boundaries']) == len(flags['lr_rates']) - 1,
    'Number of learning rates must be exactly one greater than the number of boundaries'
)

flags.DEFINE_float('l2_strength', 1e-4,
                   'The L2 regularization parameter applied to weights.')

flags.DEFINE_float(
    'value_cost_weight', 1.0,
    'Scalar for value_cost, AGZ paper suggests 1/100 for '
    'supervised learning')

flags.DEFINE_float('sgd_momentum', 0.9,
                   'Momentum parameter for learning rate.')

flags.DEFINE_string(
Example #27
flags.DEFINE_bool(
    'use_bt', False, 'Whether to use Bigtable as input.  '
    '(Only supported with --use_tpu, currently.)')

flags.DEFINE_bool('freeze', False,
                  'Whether to freeze the graph at the end of training.')

flags.DEFINE_boolean('profile_hvd', False,
                     'Whether to profile horovod based multi-gpu training.')

flags.DEFINE_boolean('use_trt', False,
                     'True to write a GraphDef that uses the TRT runtime')
flags.DEFINE_integer('trt_max_batch_size', None, 'Maximum TRT batch size')
flags.DEFINE_string('trt_precision', 'fp32',
                    'Precision for TRT runtime: fp16, fp32 or int8')
flags.register_multi_flags_validator(
    ['use_trt', 'trt_max_batch_size'],
    lambda flags: not flags['use_trt'] or flags['trt_max_batch_size'],
    'trt_max_batch_size must be set if use_trt is true')

flags.register_multi_flags_validator(
    ['use_bt', 'use_tpu'], lambda flags: flags['use_tpu']
    if flags['use_bt'] else True,
    '`use_bt` flag only valid with `use_tpu` as well')


@flags.multi_flags_validator(
    ['num_examples', 'steps_to_train', 'filter_amount'],
    '`num_examples` requires `steps_to_train==0` and `filter_amount==1.0`')
def _example_flags_validator(flags_dict):
    if not flags_dict['num_examples']:
        return True
    return not flags_dict['steps_to_train'] and flags_dict[
        'filter_amount'] == 1.0
Example #28
########################################
# Recording operators.
########################################
flags.DEFINE_string('data_path', 'data/', 'Path where to log data')
flags.DEFINE_bool('log_detector_output', False,
                  'Enable recording of bbox-annotated detector images')
flags.DEFINE_bool('log_traffic_light_detector_output', False,
                  'Enable recording of bbox-annotated traffic light detector '
                  'images')

# Flag validators.
flags.register_multi_flags_validator(
    [
        'obstacle_detection', 'obstacle_detection_model_paths',
        'obstacle_detection_model_names'
    ],
    lambda flags_dict: (not flags_dict['obstacle_detection'] or
                        (flags_dict['obstacle_detection'] and
                         (len(flags_dict['obstacle_detection_model_paths']) ==
                          len(flags_dict['obstacle_detection_model_names'])))),
    message='--obstacle_detection_model_paths and '
    '--obstacle_detection_model_names must have the same length')


def prediction_validator(flags_dict):
    if flags_dict['prediction']:
        return (flags_dict['obstacle_tracking']
                or flags_dict['perfect_obstacle_tracking'])
    return True


flags.register_multi_flags_validator(