def testParameterServer(self):
  """Distributed run (2/2 cluster) with the unmodified default params."""
  name = 'testParameterServer'
  self._test_distributed(name, 2, 2, test_util.get_params(name))
 def testCpuAsLocalParamDevice(self):
   """Local train/eval run with local_parameter_device='cpu'."""
   defaults = test_util.get_params('testCpuAsLocalParamDevice')
   self._train_and_eval_local(
       defaults._replace(local_parameter_device='cpu'))
 def testCpuAsDevice(self):
   """Local train/eval run with device='cpu'."""
   # --device=cpu requires the NHWC data format, so both are overridden.
   overrides = dict(device='cpu', data_format='NHWC')
   params = test_util.get_params('testCpuAsDevice')._replace(**overrides)
   self._train_and_eval_local(params)
 def testReplicated(self):
   """Local train/eval run with variable_update='replicated'."""
   base = test_util.get_params('testReplicated')
   self._train_and_eval_local(base._replace(variable_update='replicated'))
 def testForwardOnly(self):
   """Local run with forward_only=True; the eval phase is skipped."""
   forward_params = test_util.get_params('testForwardOnly')._replace(
       forward_only=True)
   # --forward_only does not support evaluation, hence skip='eval'.
   self._train_and_eval_local(forward_params, skip='eval')
 def testAlexnet(self):
   """Local train/eval run of the alexnet model: 30 batches, lr 0.01."""
   alexnet_overrides = {
       'model': 'alexnet',
       'num_batches': 30,
       'init_learning_rate': 0.01,
   }
   params = test_util.get_params('testAlexnet')._replace(**alexnet_overrides)
   self._train_and_eval_local(params)
 def testParameterServer(self):
   """Local train/eval run with the unmodified default params."""
   # NOTE(review): a distributed testParameterServer also appears in this
   # listing; if both live in one class, the later definition shadows the
   # earlier one.
   self._train_and_eval_local(test_util.get_params('testParameterServer'))
 def testOneWorkerThreePses(self):
   """Distributed run with one worker and three parameter servers."""
   name = 'testOneWorkerThreePses'
   self._test_distributed(name, 1, 3, test_util.get_params(name))
 def testThreeWorkersOnePs(self):
   """Distributed run with three workers and a single parameter server."""
   name = 'testThreeWorkersOnePs'
   num_workers, num_ps = 3, 1
   self._test_distributed(
       name, num_workers, num_ps, test_util.get_params(name))
 def testSingleWorkerAndPs(self):
   """Distributed run with exactly one worker and one parameter server."""
   label = 'testSingleWorkerAndPs'
   workers = ps_tasks = 1
   self._test_distributed(
       label, workers, ps_tasks, test_util.get_params(label))
 def testThreeWorkersAndPses(self):
   """Distributed run with three workers and three parameter servers."""
   test_label = 'testThreeWorkersAndPses'
   self._test_distributed(
       test_label, 3, 3, test_util.get_params(test_label))
 def testForwardOnly(self):
   """Distributed (2/2 cluster) forward-only run; eval is skipped."""
   # NOTE(review): a local testForwardOnly also appears in this listing; if
   # both live in one class, the later definition shadows the earlier one.
   name = 'testForwardOnly'
   forward_params = test_util.get_params(name)._replace(forward_only=True)
   # --forward_only does not support evaluation, hence skip='eval'.
   self._test_distributed(name, 2, 2, forward_params, skip='eval')
 def testAllReducePscpuXring(self):
   """Distributed all-reduce run using the 'pscpu:2k:xring' spec.

   Passes 2 and 0 as the positional cluster counts (presumably workers and
   parameter servers; confirm against `_test_distributed`) and a single
   controller, with variable_update='distributed_all_reduce'.
   """
   test_name = 'testAllReducePscpuXring'
   # `_replace` yields an updated params object (not a dict), so it is named
   # `params` like in every other test here.
   params = test_util.get_params(test_name)._replace(
       variable_update='distributed_all_reduce',
       all_reduce_spec='pscpu:2k:xring')
   self._test_distributed(test_name, 2, 0, params, num_controllers=1)
 def testParameterServerStaged(self):
   """Distributed (2/2 cluster) run with staged_vars=True."""
   name = 'testParameterServerStaged'
   staged = test_util.get_params(name)._replace(staged_vars=True)
   self._test_distributed(name, 2, 2, staged)
 def testFp16WithFp16Vars(self):
   """Distributed fp16 run with fp16_vars=True and fp16_loss_scale=1."""
   name = 'testFp16WithFp16Vars'
   fp16_overrides = dict(use_fp16=True, fp16_vars=True, fp16_loss_scale=1.)
   params = test_util.get_params(name)._replace(**fp16_overrides)
   self._test_distributed(name, 2, 2, params)
 def testNoPrintTrainingAccuracy(self):
   """Distributed run with print_training_accuracy=False."""
   name = 'testNoPrintTrainingAccuracy'
   quiet = test_util.get_params(name)._replace(print_training_accuracy=False)
   self._test_distributed(name, 2, 2, quiet)
 def testFp16Replicated(self):
   """Distributed fp16 run with variable_update='distributed_replicated'."""
   name = 'testFp16Replicated'
   base = test_util.get_params(name)
   params = base._replace(
       use_fp16=True, variable_update='distributed_replicated')
   self._test_distributed(name, 2, 2, params)
 def testRmspropParameterServer(self):
   """Distributed (2/2 cluster) run with optimizer='rmsprop'."""
   name = 'testRmspropParameterServer'
   self._test_distributed(
       name, 2, 2, test_util.get_params(name)._replace(optimizer='rmsprop'))
 def testNoPrintAccuracy(self):
   """Local train/eval run with print_training_accuracy=False."""
   overrides = {'print_training_accuracy': False}
   params = test_util.get_params('testNoPrintAccuracy')._replace(**overrides)
   self._train_and_eval_local(params)
 def testMomentumReplicated(self):
   """Distributed momentum run with distributed_replicated updates."""
   name = 'testMomentumReplicated'
   overrides = dict(optimizer='momentum',
                    variable_update='distributed_replicated')
   self._test_distributed(
       name, 2, 2, test_util.get_params(name)._replace(**overrides))
 def testParameterServerStaged(self):
   """Local train/eval run with staged_vars=True."""
   # NOTE(review): a distributed testParameterServerStaged also appears in
   # this listing; if both live in one class, the later definition shadows
   # the earlier one.
   base = test_util.get_params('testParameterServerStaged')
   self._train_and_eval_local(base._replace(staged_vars=True))
 def testNoCrossReplicaSyncParameterServerStaged(self):
   """Distributed staged-vars run with cross_replica_sync=False."""
   name = 'testNoCrossReplicaSyncParameterServerStaged'
   overrides = {'staged_vars': True, 'cross_replica_sync': False}
   params = test_util.get_params(name)._replace(**overrides)
   self._test_distributed(name, 2, 2, params)
 def testIndependent(self):
   """Local train/eval run with variable_update='independent'."""
   base = test_util.get_params('testIndependent')
   independent = base._replace(variable_update='independent')
   self._train_and_eval_local(independent)
 def testSingleGpu(self):
   """Distributed (2/2 cluster) run with num_gpus=1."""
   name = 'testSingleGpu'
   one_gpu = test_util.get_params(name)._replace(num_gpus=1)
   self._test_distributed(name, 2, 2, one_gpu)
 def testNoDistortions(self):
   """Local train/eval run with distortions=False."""
   base = test_util.get_params('testNoDistortions')
   self._train_and_eval_local(base._replace(distortions=False))
 def testBatchGroupSize(self):
   """Distributed run: batch_group_size=4, 100 batches, 5 warm-up."""
   name = 'testBatchGroupSize'
   overrides = {
       'batch_group_size': 4,
       'num_batches': 100,
       'num_warmup_batches': 5,
   }
   self._test_distributed(
       name, 2, 2, test_util.get_params(name)._replace(**overrides))
 def testNHWC(self):
   """Local train/eval run with data_format='NHWC'."""
   base = test_util.get_params('testNHWC')
   self._train_and_eval_local(base._replace(data_format='NHWC'))
 def testFp16WithFp32Vars(self):
   """Distributed fp16 run with fp16_vars=False (fp32 variables)."""
   name = 'testFp16WithFp32Vars'
   overrides = dict(use_fp16=True, fp16_vars=False)
   self._test_distributed(
       name, 2, 2, test_util.get_params(name)._replace(**overrides))
 def testMomentumParameterServer(self):
   """Local train/eval run with optimizer='momentum' and momentum=0.8."""
   base = test_util.get_params('testMomentumParameterServer')
   momentum_params = base._replace(momentum=0.8, optimizer='momentum')
   self._train_and_eval_local(momentum_params)
Example #30
0
 def testCifar10SyntheticData(self):
   """Local train/eval run with data_name='cifar10'."""
   base = test_util.get_params('testCifar10SyntheticData')
   self._train_and_eval_local(base._replace(data_name='cifar10'))