def testVariableWithVariableDeviceChooser(self):
    """global_step honours a VariableDeviceChooser supplied via arg_scope.

    Two global_step() calls in the same graph must yield the same variable.
    Each result must be placed on 'cpu:0', with its initial value on the
    same device as the variable.
    """
    with tf.Graph().as_default():
        device_fn = variables.VariableDeviceChooser()
        with scopes.arg_scope([variables.global_step], device=device_fn):
            gs = variables.global_step()
            gs2 = variables.global_step()
            # assertEquals is a deprecated alias (removed in Python 3.12).
            self.assertEqual(gs, gs2)
            for step in (gs, gs2):
                self.assertDeviceEqual(step.device, 'cpu:0')
                self.assertDeviceEqual(step.initial_value.device, step.device)
def testReplicaDeviceSetter(self):
    """global_step honours tf.train.replica_device_setter via arg_scope.

    The setter is built with 2 parameter-server tasks. Both global_step()
    calls must return the same variable. The variable and its initial value
    must both be placed on '/job:ps/task:0'.
    """
    device_fn = tf.train.replica_device_setter(2)
    with tf.Graph().as_default():
        with scopes.arg_scope([variables.global_step], device=device_fn):
            gs = variables.global_step()
            gs2 = variables.global_step()
            # assertEquals is a deprecated alias (removed in Python 3.12).
            self.assertEqual(gs, gs2)
            for step in (gs, gs2):
                self.assertDeviceEqual(step.device, '/job:ps/task:0')
                self.assertDeviceEqual(step.initial_value.device,
                                       '/job:ps/task:0')
def testDeviceFn(self):
    """global_step accepts an arbitrary callable as its device function.

    DevFn returns '/cpu:<n>', where n grows by one on every invocation.
    The second global_step() call is asserted to return the same variable
    as the first. Its device is therefore still expected to be '/cpu:0',
    even though DevFn would yield a different device if invoked again.
    """

    class DevFn(object):
        """Device function returning '/cpu:0', '/cpu:1', ... on each call."""

        def __init__(self):
            # Starts at -1 so that the first call produces '/cpu:0'.
            self.counter = -1

        def __call__(self, op):
            self.counter += 1
            return '/cpu:%d' % self.counter

    with tf.Graph().as_default():
        with scopes.arg_scope([variables.global_step], device=DevFn()):
            gs = variables.global_step()
            gs2 = variables.global_step()
            self.assertDeviceEqual(gs.device, '/cpu:0')
            # assertEquals is a deprecated alias (removed in Python 3.12).
            self.assertEqual(gs, gs2)
            self.assertDeviceEqual(gs2.device, '/cpu:0')
def testDevice(self):
    """A plain device string given through arg_scope pins global_step there."""
    target_device = '/gpu:0'
    with tf.Graph().as_default(), \
            scopes.arg_scope([variables.global_step], device=target_device):
        step = variables.global_step()
        self.assertDeviceEqual(step.device, target_device)
def testStable(self):
    """Repeated global_step() calls in one graph return the identical object."""
    with tf.Graph().as_default():
        gs = variables.global_step()
        gs2 = variables.global_step()
        # assertIs reports both operands on failure, unlike assertTrue(a is b).
        self.assertIs(gs, gs2)