コード例 #1
0
    def testVariableWithVariableDeviceChooser(self):
        """global_step() honors a VariableDeviceChooser passed via arg_scope.

        Two calls to variables.global_step() inside the scope must return the
        same variable. Both the variable and its initial value are expected
        on 'cpu:0'.
        """
        with tf.Graph().as_default():
            device_fn = variables.VariableDeviceChooser()
            with scopes.arg_scope([variables.global_step], device=device_fn):
                gs = variables.global_step()
                gs2 = variables.global_step()
                # assertEqual replaces the deprecated assertEquals alias
                # (removed from unittest in Python 3.12).
                self.assertEqual(gs, gs2)
                self.assertDeviceEqual(gs.device, 'cpu:0')
                # The initial value is expected on the same device as the
                # variable itself.
                self.assertDeviceEqual(gs.initial_value.device, gs.device)
                self.assertDeviceEqual(gs2.device, 'cpu:0')
                self.assertDeviceEqual(gs2.initial_value.device, gs2.device)
コード例 #2
0
 def testReplicaDeviceSetter(self):
     """global_step() honors tf.train.replica_device_setter via arg_scope.

     The setter is built with 2 as its first positional argument (ps_tasks).
     Both global_step() calls must return the same variable. The variable and
     its initial value are expected on '/job:ps/task:0'.
     """
     device_fn = tf.train.replica_device_setter(2)
     with tf.Graph().as_default():
         with scopes.arg_scope([variables.global_step], device=device_fn):
             gs = variables.global_step()
             gs2 = variables.global_step()
             # assertEqual replaces the deprecated assertEquals alias
             # (removed from unittest in Python 3.12).
             self.assertEqual(gs, gs2)
             self.assertDeviceEqual(gs.device, '/job:ps/task:0')
             self.assertDeviceEqual(gs.initial_value.device,
                                    '/job:ps/task:0')
             self.assertDeviceEqual(gs2.device, '/job:ps/task:0')
             self.assertDeviceEqual(gs2.initial_value.device,
                                    '/job:ps/task:0')
コード例 #3
0
    def testDeviceFn(self):
        """global_step() honors a plain callable device function via arg_scope.

        DevFn returns '/cpu:0', '/cpu:1', ... on successive calls. The test
        expects both global_step() results to be the same variable, placed on
        '/cpu:0'.
        """
        class DevFn(object):
            """Device function returning a new '/cpu:N' string on each call."""

            def __init__(self):
                # Starts at -1 so the first call returns '/cpu:0'.
                self.counter = -1

            def __call__(self, op):
                # `op` is ignored; the result depends only on the call count.
                self.counter += 1
                return '/cpu:%d' % self.counter

        with tf.Graph().as_default():
            with scopes.arg_scope([variables.global_step], device=DevFn()):
                gs = variables.global_step()
                gs2 = variables.global_step()
            self.assertDeviceEqual(gs.device, '/cpu:0')
            # assertEqual replaces the deprecated assertEquals alias
            # (removed from unittest in Python 3.12).
            self.assertEqual(gs, gs2)
            self.assertDeviceEqual(gs2.device, '/cpu:0')
コード例 #4
0
 def testDevice(self):
     """A device string given via arg_scope is where global_step lands."""
     gpu_device = '/gpu:0'
     graph = tf.Graph()
     with graph.as_default():
         with scopes.arg_scope([variables.global_step], device=gpu_device):
             step = variables.global_step()
         # Checked after leaving arg_scope, inside the same graph context.
         self.assertDeviceEqual(step.device, gpu_device)
コード例 #5
0
 def testStable(self):
     """Repeated global_step() calls in one graph return the same object."""
     with tf.Graph().as_default():
         gs = variables.global_step()
         gs2 = variables.global_step()
         # assertIs checks identity like `gs is gs2`, but on failure it shows
         # both objects instead of only "False is not true".
         self.assertIs(gs, gs2)