コード例 #1
0
 def testReplicatedRealData(self):
     """Run distributed_replicated training on the fake TFRecord data.

     The 'fake_tf_record_data' directory under the platform test-data dir is
     fed in as the 'imagenet' dataset; 2 workers and 2 parameter servers.
     """
     name = 'testReplicatedRealData'
     data_root = platforms_util.get_test_data_dir()
     imagenet_dir = os.path.join(data_root, 'fake_tf_record_data')
     overrides = {
         'variable_update': 'distributed_replicated',
         'data_dir': imagenet_dir,
         'data_name': 'imagenet',
     }
     params = test_util.get_params(name)._replace(**overrides)
     self._test_distributed(name, 2, 2, params)
コード例 #2
0
 def testFp16WithFp16Vars(self):
     """Run with fp16 enabled, fp16 variables, and fp16_loss_scale of 1.0."""
     name = 'testFp16WithFp16Vars'
     base_params = test_util.get_params(name)
     params = base_params._replace(
         use_fp16=True,
         fp16_vars=True,
         fp16_loss_scale=1.,
     )
     self._test_distributed(name, 2, 2, params)
コード例 #3
0
 def testFp16WithFp32Vars(self):
     """Run with fp16 enabled but fp16_vars disabled (fp32 variables)."""
     name = 'testFp16WithFp32Vars'
     fp16_overrides = dict(use_fp16=True, fp16_vars=False)
     params = test_util.get_params(name)._replace(**fp16_overrides)
     self._test_distributed(name, 2, 2, params)
コード例 #4
0
 def testBatchGroupSize(self):
     """Run with batch_group_size=4 for 100 batches after 5 warmup batches."""
     name = 'testBatchGroupSize'
     params = test_util.get_params(name)
     params = params._replace(
         batch_group_size=4, num_batches=100, num_warmup_batches=5)
     self._test_distributed(name, 2, 2, params)
コード例 #5
0
 def testSingleGpu(self):
     """Run the distributed setup with num_gpus overridden to 1."""
     name = 'testSingleGpu'
     single_gpu_params = test_util.get_params(name)._replace(num_gpus=1)
     self._test_distributed(name, 2, 2, single_gpu_params)
コード例 #6
0
 def testNoCrossReplicaSyncParameterServerStaged(self):
     """Run with staged variables and cross_replica_sync turned off."""
     name = 'testNoCrossReplicaSyncParameterServerStaged'
     overrides = {'staged_vars': True, 'cross_replica_sync': False}
     params = test_util.get_params(name)._replace(**overrides)
     self._test_distributed(name, 2, 2, params)
コード例 #7
0
 def testMomentumReplicated(self):
     """Run the momentum optimizer with distributed_replicated variables."""
     name = 'testMomentumReplicated'
     params = test_util.get_params(name)._replace(
         variable_update='distributed_replicated',
         optimizer='momentum',
     )
     self._test_distributed(name, 2, 2, params)
コード例 #8
0
 def testParameterServerStaged(self):
     """Run the default distributed setup with staged_vars enabled."""
     name = 'testParameterServerStaged'
     base = test_util.get_params(name)
     self._test_distributed(name, 2, 2, base._replace(staged_vars=True))
コード例 #9
0
 def testNoPrintTrainingAccuracy(self):
     """Run with print_training_accuracy disabled."""
     name = 'testNoPrintTrainingAccuracy'
     base_params = test_util.get_params(name)
     params = base_params._replace(print_training_accuracy=False)
     self._test_distributed(name, 2, 2, params)
コード例 #10
0
 def testThreeWorkersOnePs(self):
     """Run default params on a cluster of 3 workers and 1 parameter server."""
     name = 'testThreeWorkersOnePs'
     num_workers, num_ps = 3, 1
     self._test_distributed(name, num_workers, num_ps,
                            test_util.get_params(name))
コード例 #11
0
 def testOneWorkerThreePses(self):
     """Run default params on a cluster of 1 worker and 3 parameter servers."""
     name = 'testOneWorkerThreePses'
     default_params = test_util.get_params(name)
     workers, pses = 1, 3
     self._test_distributed(name, workers, pses, default_params)
コード例 #12
0
 def testThreeWorkersAndPses(self):
     """Run default params with 3 workers and 3 parameter servers."""
     name = 'testThreeWorkersAndPses'
     cluster_size = 3  # Same count is used for workers and parameter servers.
     self._test_distributed(name, cluster_size, cluster_size,
                            test_util.get_params(name))
コード例 #13
0
 def testSingleWorkerAndPs(self):
     """Run default params with a single worker and a single parameter server."""
     name = 'testSingleWorkerAndPs'
     default_params = test_util.get_params(name)
     self._test_distributed(name, 1, 1, default_params)
コード例 #14
0
 def testForwardOnly(self):
     """Run with forward_only=True, skipping the eval phase."""
     name = 'testForwardOnly'
     params = test_util.get_params(name)
     params = params._replace(forward_only=True)
     # --forward_only does not support evaluation, hence skip='eval'.
     self._test_distributed(name, 2, 2, params, skip='eval')
コード例 #15
0
 def testAllReducePscpuXring(self):
     """Run distributed_all_reduce with all_reduce_spec 'pscpu:2k:xring'.

     Uses 2 workers, 0 parameter servers and num_controllers=1.
     """
     test_name = 'testAllReducePscpuXring'
     params = test_util.get_params(test_name)._replace(
         variable_update='distributed_all_reduce',
         all_reduce_spec='pscpu:2k:xring')
     self._test_distributed(test_name, 2, 0, params, num_controllers=1)
コード例 #16
0
 def testFp16Replicated(self):
     """Run fp16 with distributed_replicated variable updates."""
     name = 'testFp16Replicated'
     replicated_fp16 = dict(variable_update='distributed_replicated',
                            use_fp16=True)
     params = test_util.get_params(name)._replace(**replicated_fp16)
     self._test_distributed(name, 2, 2, params)
コード例 #17
0
 def testRmspropParameterServer(self):
     """Run the default parameter-server setup with the rmsprop optimizer."""
     name = 'testRmspropParameterServer'
     optimizer = 'rmsprop'
     params = test_util.get_params(name)._replace(optimizer=optimizer)
     self._test_distributed(name, 2, 2, params)
コード例 #18
0
 def testParameterServer(self):
     """Run unmodified default params with 2 workers and 2 parameter servers."""
     name = 'testParameterServer'
     self._test_distributed(name, 2, 2, test_util.get_params(name))