def test_run_benchmark(self):
  """Tests that run_benchmark() runs successfully."""
  # First pass: two GPUs with the plain 'replicated' variable update.
  base_params = benchmark_cnn.make_params(
      num_batches=10, variable_update='replicated', num_gpus=2)
  self._test_run_benchmark(base_params)
  # Second pass: scale to eight GPUs and exercise hierarchical copy with
  # gradient repacking on top of the same base configuration.
  hierarchical_params = base_params._replace(
      hierarchical_copy=True, gradient_repacking=8, num_gpus=8)
  self._test_run_benchmark(hierarchical_params)
def get_var_update_params():
  """Returns params that are used when testing variable updates."""
  # Kept as an explicit mapping so individual tests can eyeball or tweak
  # the overrides in one place before they are handed to make_params().
  overrides = {
      'batch_size': 2,
      'model': 'test_model',
      'num_gpus': 2,
      'display_every': 1,
      'num_warmup_batches': 0,
      'num_batches': 4,
      'weight_decay': 2 ** -4,
      'init_learning_rate': 2 ** -4,
      'optimizer': 'sgd',
  }
  return benchmark_cnn.make_params(**overrides)
def get_params(train_dir_name):
  """Returns params that can be used to train."""
  # Checkpoints go to a per-test temp dir derived from train_dir_name.
  checkpoint_dir = get_temp_dir(train_dir_name)
  training_overrides = {
      'batch_size': 2,
      'display_every': 1,
      'init_learning_rate': 0.005,
      'model': 'trivial',
      'num_batches': 20,
      'num_gpus': 2,
      'num_warmup_batches': 5,
      'optimizer': 'sgd',
      'print_training_accuracy': True,
      'train_dir': checkpoint_dir,
      'variable_update': 'parameter_server',
      'weight_decay': 0,
  }
  return benchmark_cnn.make_params(**training_overrides)
def testMlPerfCompliance(self):
  """Runs a short train/eval cycle and checks the emitted MLPerf logs.

  Captures everything written to the MLPerf logger, then verifies that the
  number of lines matching each expected regex is exactly the expected
  count, reporting per-regex differences on failure.
  """
  log_capture = six.StringIO()
  capture_handler = logging.StreamHandler(log_capture)
  data_dir = test_util.create_black_and_white_images()
  try:
    mlperf_log.LOGGER.addHandler(capture_handler)
    params = benchmark_cnn.make_params(
        data_dir=data_dir,
        data_name='imagenet',
        batch_size=2,
        num_warmup_batches=0,
        num_batches=2,
        num_eval_batches=3,
        eval_during_training_every_n_steps=1,
        distortions=False,
        weight_decay=0.5,
        optimizer='momentum',
        momentum=0.5,
        stop_at_top_1_accuracy=2.0,
        tf_random_seed=9876,
        ml_perf=True)
    with mlperf.mlperf_logger(use_mlperf_logger=True, model='resnet50_v1.5'):
      bench_cnn = benchmark_cnn.BenchmarkCNN(params, model=_MlPerfTestModel())
      bench_cnn.run()

    # Tally how many captured lines each expected regex matched.
    captured_lines = log_capture.getvalue().splitlines()
    matched_counts = Counter(
        regex
        for line in captured_lines
        for regex in self.EXPECTED_LOG_REGEXES
        if regex.search(line))

    if matched_counts != self.EXPECTED_LOG_REGEXES:
      # Subtracting yields a signed per-regex delta; nonzero means mismatch.
      delta = Counter(matched_counts)
      delta.subtract(self.EXPECTED_LOG_REGEXES)
      differences = []
      for regex in delta:
        if not delta[regex]:
          continue
        found_count = matched_counts[regex]
        expected_count = self.EXPECTED_LOG_REGEXES[regex]
        differences.append(
            ' For regex %s: Found %d lines matching but '
            'expected to find %d' % (regex.pattern, found_count,
                                     expected_count))
      raise AssertionError(
          'Logs did not match expected logs. Differences:\n'
          '%s' % '\n'.join(differences))
  finally:
    # Always detach the capture handler so later tests see a clean logger.
    mlperf_log.LOGGER.removeHandler(capture_handler)
def setUp(self):
  """Resets global benchmark state with default params before each test."""
  super(MlPerfComplianceTest, self).setUp()
  default_params = benchmark_cnn.make_params()
  benchmark_cnn.setup(default_params)