Example no. 1
0
 def test_run_benchmark(self):
     """Tests that run_benchmark() runs successfully.

     Runs two configurations: a 2-GPU replicated-variable setup, then the
     same params scaled to 8 GPUs with hierarchical copy and gradient
     repacking enabled.
     """
     # First pass: plain replicated variable updates on 2 GPUs.
     base_params = benchmark_cnn.make_params(num_batches=10,
                                             variable_update='replicated',
                                             num_gpus=2)
     self._test_run_benchmark(base_params)
     # Second pass: scale up to 8 GPUs and exercise hierarchical copy
     # with gradient repacking.
     scaled_params = base_params._replace(hierarchical_copy=True,
                                          gradient_repacking=8,
                                          num_gpus=8)
     self._test_run_benchmark(scaled_params)
def get_var_update_params():
    """Returns params that are used when testing variable updates."""
    # Collect the overrides in one place, then hand them to make_params.
    # Learning rate and weight decay are powers of two so test arithmetic
    # stays exactly representable in floating point.
    overrides = dict(
        batch_size=2,
        model='test_model',
        num_gpus=2,
        display_every=1,
        num_warmup_batches=0,
        num_batches=4,
        weight_decay=2**-4,
        init_learning_rate=2**-4,
        optimizer='sgd',
    )
    return benchmark_cnn.make_params(**overrides)
def get_params(train_dir_name):
    """Returns params that can be used to train.

    Args:
      train_dir_name: Name passed to get_temp_dir() to build the
        checkpoint/train directory for this run.
    """
    # A small 'trivial' model on 2 GPUs with a parameter-server update;
    # kwargs are gathered into a dict first and splatted into make_params.
    training_overrides = dict(
        batch_size=2,
        display_every=1,
        init_learning_rate=0.005,
        model='trivial',
        num_batches=20,
        num_gpus=2,
        num_warmup_batches=5,
        optimizer='sgd',
        print_training_accuracy=True,
        train_dir=get_temp_dir(train_dir_name),
        variable_update='parameter_server',
        weight_decay=0,
    )
    return benchmark_cnn.make_params(**training_overrides)
Example no. 4
0
 def testMlPerfCompliance(self):
     """Runs a short train/eval cycle with MLPerf logging enabled and
     verifies the emitted compliance log lines match expectations.

     Captures everything written to mlperf_log.LOGGER into an in-memory
     buffer, runs BenchmarkCNN for a couple of batches, then counts how
     many captured lines match each regex in self.EXPECTED_LOG_REGEXES
     (assumed to be a Counter mapping compiled regexes to expected match
     counts -- TODO confirm). Raises AssertionError listing every regex
     whose observed count differs from the expected count.
     """
     # Capture MLPerf log output in memory instead of the real handlers.
     string_io = six.StringIO()
     handler = logging.StreamHandler(string_io)
     data_dir = test_util.create_black_and_white_images()
     try:
         # Attach before running so every compliance line is captured;
         # detached in the finally block to avoid leaking into other tests.
         mlperf_log.LOGGER.addHandler(handler)
         # Tiny, fully deterministic run: fixed seed, no warmup, eval
         # interleaved every step so eval-related log tags also appear.
         params = benchmark_cnn.make_params(
             data_dir=data_dir,
             data_name='imagenet',
             batch_size=2,
             num_warmup_batches=0,
             num_batches=2,
             num_eval_batches=3,
             eval_during_training_every_n_steps=1,
             distortions=False,
             weight_decay=0.5,
             optimizer='momentum',
             momentum=0.5,
             stop_at_top_1_accuracy=2.0,
             tf_random_seed=9876,
             ml_perf=True)
         with mlperf.mlperf_logger(use_mlperf_logger=True,
                                   model='resnet50_v1.5'):
             bench_cnn = benchmark_cnn.BenchmarkCNN(
                 params, model=_MlPerfTestModel())
             bench_cnn.run()
         logs = string_io.getvalue().splitlines()
         # Count, for each expected regex, how many log lines match it.
         # A line can be counted for several regexes (no break).
         log_regexes = Counter()
         for log in logs:
             for regex in self.EXPECTED_LOG_REGEXES:
                 if regex.search(log):
                     log_regexes[regex] += 1
         if log_regexes != self.EXPECTED_LOG_REGEXES:
             # Build a readable per-regex diff of found vs expected counts.
             diff_counter = Counter(log_regexes)
             diff_counter.subtract(self.EXPECTED_LOG_REGEXES)
             differences = []
             # Only report regexes whose found/expected counts differ
             # (non-zero entries after the subtraction).
             for regex in (k for k in diff_counter.keys()
                           if diff_counter[k]):
                 found_count = log_regexes[regex]
                 expected_count = self.EXPECTED_LOG_REGEXES[regex]
                 differences.append(
                     '  For regex %s: Found %d lines matching but '
                     'expected to find %d' %
                     (regex.pattern, found_count, expected_count))
             raise AssertionError(
                 'Logs did not match expected logs. Differences:\n'
                 '%s' % '\n'.join(differences))
     finally:
         # Always detach the capture handler, even if the run failed.
         mlperf_log.LOGGER.removeHandler(handler)
Example no. 5
0
 def setUp(self):
     """Initializes default benchmark_cnn global state before each test."""
     super(MlPerfComplianceTest, self).setUp()
     default_params = benchmark_cnn.make_params()
     benchmark_cnn.setup(default_params)