def main(argv):
    """Print the mobile cost estimate for the architecture selected by flags.

    The architecture is described by the colon-separated ``--indices`` flag,
    parsed as ints, together with the ``--ssd`` search-space flag.

    Args:
      argv: Positional command-line arguments; anything beyond argv[0]
        (presumably the program name) is rejected.

    Raises:
      app.UsageError: If extra positional arguments were supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Arguments are evaluated left to right: indices are parsed first, then
    # the ssd flag is read, matching the original order.
    estimated = mobile_cost_model.estimate_cost(
        search_space_utils.parse_list(FLAGS.indices, int), FLAGS.ssd)
    print('estimated cost: {:f}'.format(estimated))
 def test_parse_list(self):
     """parse_list splits on ':' and converts each piece with item_type."""
     # (input string, element type, expected list) cases that can be
     # compared exactly. Empty and whitespace-only input yield an empty list.
     exact_cases = (
         ('', str, []),
         ('   \t', str, []),
         ('hello', str, ['hello']),
         ('he:lo', str, ['he', 'lo']),
         ('42', int, [42]),
         ('4:2', int, [4, 2]),
         ('1:2:3', int, [1, 2, 3]),
     )
     for text, item_type, expected in exact_cases:
         self.assertEqual(expected,
                          search_space_utils.parse_list(text, item_type))

     # Float results are checked with a tolerance instead of exact equality.
     self.assertAllClose([1.5, 2.5],
                         search_space_utils.parse_list('1.5:2.5', float))
def main(argv):
  """Resolve training hyperparameters from flags, then run the model.

  Flags left as None (epochs, weight decay, dropout rate) are replaced with
  defaults derived from other flags. The resulting params dict is persisted
  via _write_params_to_checkpoint_dir in 'train'/'train_and_eval' modes and
  is always passed to run_model.

  Args:
    argv: Positional command-line arguments; ignored.
  """
  del argv  # Unused.

  # The held-out-test-set configuration uses the 'train' split with a 360
  # epoch default; otherwise the 'l2l_train' split with a 90 epoch default.
  if FLAGS.use_held_out_test_set:
    split_name, fallback_epochs = 'train', 360
  else:
    split_name, fallback_epochs = 'l2l_train', 90
  train_dataset_size = fast_imagenet_input.dataset_size_for_mode(split_name)

  requested_epochs = FLAGS.epochs
  epochs = fallback_epochs if requested_epochs is None else requested_epochs

  # An explicit flag wins; otherwise the default depends on bfloat16 usage.
  if FLAGS.weight_decay is not None:
    weight_decay = FLAGS.weight_decay
  elif FLAGS.use_bfloat16:
    weight_decay = 3e-5
  else:
    weight_decay = 4e-5

  # Automatic dropout selection when the flag is unset. For MnasNet-sized
  # models, dropout substantially improves accuracy when training for 360
  # epochs, but whether it helps at 90 epochs hasn't been investigated, so
  # dropout is enabled only for runs of at least 150 epochs.
  if FLAGS.dropout_rate is not None:
    dropout_rate = FLAGS.dropout_rate
  elif epochs < 150:
    dropout_rate = 0
  elif FLAGS.ssd in mobile_search_space_v3.MOBILENET_V3_LIKE_SSDS:
    # MobileNetV3-based search space.
    dropout_rate = 0.25
  else:
    # MobileNetV2-based search space.
    dropout_rate = 0.15

  batch_size = FLAGS.train_batch_size
  max_global_step = (train_dataset_size * epochs) // batch_size

  params = {
      'checkpoint_dir': FLAGS.checkpoint_dir,
      'dataset_dir': FLAGS.dataset_dir,
      # Base learning rate scaled linearly with batch size relative to 256.
      'learning_rate': FLAGS.base_learning_rate * batch_size / 256,
      'tpu': FLAGS.tpu,
      'tpu_zone': FLAGS.tpu_zone,
      'gcp_project': FLAGS.gcp_project,
      'momentum': FLAGS.momentum,
      'weight_decay': weight_decay,
      'max_global_step': max_global_step,
      'warmup_steps': int(FLAGS.warmup_steps_fraction * max_global_step),
      'tpu_iterations_per_loop': FLAGS.tpu_iterations_per_loop,
      'train_batch_size': batch_size,
      'eval_batch_size': FLAGS.eval_batch_size,
      'indices': search_space_utils.parse_list(FLAGS.indices, int),
      'use_held_out_test_set': FLAGS.use_held_out_test_set,
      'use_bfloat16': FLAGS.use_bfloat16,
      'filters_multiplier': FLAGS.filters_multiplier,
      'dropout_rate': dropout_rate,
      'path_dropout_rate': FLAGS.path_dropout_rate,
      'ssd': FLAGS.ssd,
  }

  mode = FLAGS.mode
  if mode in ('train', 'train_and_eval'):
    _write_params_to_checkpoint_dir(params)
  run_model(params, mode)