コード例 #1
0
 def test_mnist_train_pipelined_processes(self):
     """Train MNIST under horovodrun with 2 processes and pipelining.

     Runs ``horovod_popart_mnist.py`` with ``--num-ipus 2 --pipeline
     --epochs 3`` and checks the accuracies reported in its output
     against reference values using ``self.accuracy_tolerances``.
     """
     cmd = ("horovodrun -np 2 -H localhost:2 python horovod_popart_mnist.py"
            " --num-ipus 2 --pipeline --epochs 3")
     output = self.run_command(cmd, self.cwd, "Accuracy")
     # Reference accuracies (%), presumably one per epoch (3 epochs run).
     expected_accuracy = [71.26, 78.77, 82.14]
     accuracies = parse_results_with_regex(output,
                                           r".* + Accuracy=+([\d.]+)%")
     # Fail with a clear message rather than an IndexError on
     # accuracies[0] when no accuracy line matched the regex.
     assert accuracies and accuracies[0], (
         "No 'Accuracy=<value>%' lines found in the training output")
     # accuracies[0] presumably holds the values captured by the regex's
     # single group, in output order.
     verify_model_accuracies(accuracies[0], expected_accuracy,
                             self.accuracy_tolerances)
コード例 #2
0
 def test_mnist_train_sixteen_processes(self):
     """Train MNIST under horovodrun with 16 processes.

     Runs ``horovod_popart_mnist.py`` with ``--epochs 3`` across 16
     local processes and checks the accuracies reported in its output
     against reference values using ``self.accuracy_tolerances``.
     """
     cmd = ("horovodrun -np 16 -H localhost:16 python horovod_popart_mnist.py"
            " --epochs 3")
     output = self.run_command(cmd, self.cwd, "Accuracy")
     # Reference accuracies (%), presumably one per epoch (3 epochs run).
     expected_accuracy = [83.75, 83.86, 84.79]
     accuracies = parse_results_with_regex(output,
                                           r".* + Accuracy=+([\d.]+)%")
     # Fail with a clear message rather than an IndexError on
     # accuracies[0] when no accuracy line matched the regex.
     assert accuracies and accuracies[0], (
         "No 'Accuracy=<value>%' lines found in the training output")
     # accuracies[0] presumably holds the values captured by the regex's
     # single group, in output order.
     verify_model_accuracies(accuracies[0], expected_accuracy,
                             self.accuracy_tolerances)
コード例 #3
0
 def test_pipelining_convergence(self):
     """Run with default settings and check that training converges.

     Calls ``run_pipelining_example`` with an empty options dict, parses
     every ``loss: <value>`` line from its output, and asserts that the
     last reported loss lies strictly between 0.001 and 0.02.
     """
     out = run_pipelining_example({})
     result = test_util.parse_results_with_regex(out, r"loss: ([\d.]+)")
     # Guard against an opaque IndexError on result[0][-1] when the run
     # printed no matching loss line.
     self.assertTrue(result and result[0],
                     "No 'loss: <value>' lines found in the example output")
     # Only the most recently reported loss is checked for convergence.
     final_loss = result[0][-1]
     self.assertGreater(final_loss, 0.001)
     self.assertLess(final_loss, 0.02)
コード例 #4
0
 def test_mnist_train_multiple_options(self):
     """Train MNIST under horovodrun with 4 processes and many options.

     Runs ``horovod_popart_mnist.py`` for 8 epochs with a non-default
     batch size, batches-per-step, 2-IPU pipelining and graph-trace
     logging, then checks the accuracies reported in its output against
     reference values using ``self.accuracy_tolerances``.
     """
     cmd = ("horovodrun -np 4 -H localhost:4 python horovod_popart_mnist.py"
            " --epochs 8 --batch-size 64 --batches-per-step 50"
            " --num-ipus 2 --pipeline --log-graph-trace")
     output = self.run_command(cmd, self.cwd, "Accuracy")
     # Reference accuracies (%), presumably one per epoch (8 epochs run).
     expected_accuracy = [
         69.80, 77.77, 81.21, 83.01, 84.11, 84.76, 85.48, 86.02
     ]
     accuracies = parse_results_with_regex(output,
                                           r".* + Accuracy=+([\d.]+)%")
     # Fail with a clear message rather than an IndexError on
     # accuracies[0] when no accuracy line matched the regex.
     assert accuracies and accuracies[0], (
         "No 'Accuracy=<value>%' lines found in the training output")
     # accuracies[0] presumably holds the values captured by the regex's
     # single group, in output order.
     verify_model_accuracies(accuracies[0], expected_accuracy,
                             self.accuracy_tolerances)