def main(argv):
  """Estimates the simulated cost of a single architecture and prints it.

  Args:
    argv: Positional command-line arguments; none are expected.

  Raises:
    app.UsageError: If any positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # The architecture is identified by its integer choice indices within the
  # given search space definition (ssd).
  op_indices = search_space_utils.parse_list(FLAGS.indices, int)
  search_space = FLAGS.ssd
  estimated = mobile_cost_model.estimate_cost(op_indices, search_space)
  print('estimated cost: {:f}'.format(estimated))
def test_parse_list(self):
  """parse_list splits its input on ':' and casts each piece with the given type."""
  # String parsing: whitespace-only input yields an empty list.
  str_cases = [
      ('', []),
      (' \t', []),
      ('hello', ['hello']),
      ('he:lo', ['he', 'lo']),
  ]
  for text, expected in str_cases:
    self.assertEqual(expected, search_space_utils.parse_list(text, str))

  # Integer parsing.
  int_cases = [
      ('42', [42]),
      ('4:2', [4, 2]),
      ('1:2:3', [1, 2, 3]),
  ]
  for text, expected in int_cases:
    self.assertEqual(expected, search_space_utils.parse_list(text, int))

  # Float parsing uses an approximate comparison.
  self.assertAllClose(
      [1.5, 2.5], search_space_utils.parse_list('1.5:2.5', float))
def main(argv):
  """Assembles training hyperparameters from flags and launches the model run.

  Fills in flag-dependent defaults (epochs, weight decay, dropout rate),
  derives the training schedule from the dataset size, writes the resolved
  params to the checkpoint directory when training, and hands off to
  run_model.

  Args:
    argv: Positional command-line arguments (ignored).
  """
  del argv  # Unused.

  # The held-out test set implies the full 'train' split and a longer
  # default schedule; otherwise we train on the 'l2l_train' split.
  if FLAGS.use_held_out_test_set:
    default_epochs = 360
    train_dataset_size = fast_imagenet_input.dataset_size_for_mode('train')
  else:
    default_epochs = 90
    train_dataset_size = fast_imagenet_input.dataset_size_for_mode('l2l_train')

  epochs = default_epochs if FLAGS.epochs is None else FLAGS.epochs

  weight_decay = FLAGS.weight_decay
  if weight_decay is None:
    weight_decay = 3e-5 if FLAGS.use_bfloat16 else 4e-5

  dropout_rate = FLAGS.dropout_rate
  if dropout_rate is None:
    # Select a dropout rate automatically. For MnasNet-sized models, dropout
    # substantially improves accuracy when training for 360 epochs, but we
    # haven't investigated whether it helps when training for 90 epochs. We
    # currently enable dropout only for long training runs.
    if epochs < 150:
      # Disable dropout when training for less than 150 epochs.
      dropout_rate = 0
    elif FLAGS.ssd in mobile_search_space_v3.MOBILENET_V3_LIKE_SSDS:
      # MobileNetV3-based search space, training for at least 150 epochs.
      dropout_rate = 0.25
    else:
      # MobileNetV2-based search space, training for at least 150 epochs.
      dropout_rate = 0.15

  # Total optimizer steps for the whole run (integer number of batches).
  max_global_step = train_dataset_size * epochs // FLAGS.train_batch_size

  params = {
      'checkpoint_dir': FLAGS.checkpoint_dir,
      'dataset_dir': FLAGS.dataset_dir,
      # Linear learning-rate scaling rule: base LR is defined per 256 examples.
      'learning_rate': FLAGS.base_learning_rate * FLAGS.train_batch_size / 256,
      'tpu': FLAGS.tpu,
      'tpu_zone': FLAGS.tpu_zone,
      'gcp_project': FLAGS.gcp_project,
      'momentum': FLAGS.momentum,
      'weight_decay': weight_decay,
      'max_global_step': max_global_step,
      'warmup_steps': int(FLAGS.warmup_steps_fraction * max_global_step),
      'tpu_iterations_per_loop': FLAGS.tpu_iterations_per_loop,
      'train_batch_size': FLAGS.train_batch_size,
      'eval_batch_size': FLAGS.eval_batch_size,
      'indices': search_space_utils.parse_list(FLAGS.indices, int),
      'use_held_out_test_set': FLAGS.use_held_out_test_set,
      'use_bfloat16': FLAGS.use_bfloat16,
      'filters_multiplier': FLAGS.filters_multiplier,
      'dropout_rate': dropout_rate,
      'path_dropout_rate': FLAGS.path_dropout_rate,
      'ssd': FLAGS.ssd,
  }

  # Persist the resolved hyperparameters alongside checkpoints for
  # reproducibility, but only for modes that actually train.
  if FLAGS.mode in ('train', 'train_and_eval'):
    _write_params_to_checkpoint_dir(params)
  run_model(params, FLAGS.mode)