Example #1
0
    def testVariableWithVariableDeviceChooser(self):
        """Check device placement of variables created under a device chooser.

        Five variables are created inside an arg_scope whose default `device`
        is a `VariableDeviceChooser` with two parameter servers. The test then
        verifies the device of each variable and of its initial value.
        """
        with tf.Graph().as_default():
            chooser = variables.VariableDeviceChooser(num_parameter_servers=2)
            # Creation order matters: the expected task indices asserted below
            # depend on the sequence in which the chooser is consulted.
            with scopes.arg_scope([variables.variable], device=chooser):
                var_a = variables.variable('a', [])
                var_b = variables.variable('b', [])
                # Explicit device overrides the arg_scope default.
                var_c = variables.variable('c', [], device='cpu:12')
                var_d = variables.variable('d', [])
                # The initializer tensor is built under its own device scope.
                with tf.device('cpu:99'):
                    pinned_init = tf.constant(12)
                var_e = variables.variable('e', initializer=pinned_init)

            # For these variables the initial value is expected to live on
            # the same device as the variable itself.
            colocated_cases = (
                (var_a, '/job:ps/task:0/cpu:0'),
                (var_b, '/job:ps/task:1/cpu:0'),
                (var_c, '/cpu:12'),
                (var_d, '/job:ps/task:0/cpu:0'),
            )
            for var, expected_device in colocated_cases:
                self.assertDeviceEqual(var.device, expected_device)
                self.assertDeviceEqual(var.initial_value.device, var.device)

            # 'e' was given a pre-built initializer: the variable is placed
            # on a ps task while its initial value keeps the device it was
            # created on.
            self.assertDeviceEqual(var_e.device, '/job:ps/task:1/cpu:0')
            self.assertDeviceEqual(var_e.initial_value.device, '/cpu:99')
Example #2
0
    def testVariableWithVariableDeviceChooser(self):
        """Check global_step placement under a default VariableDeviceChooser.

        Calls `variables.global_step()` twice inside an arg_scope whose
        `device` is a `VariableDeviceChooser()` built with default arguments,
        then verifies that both calls yield equal objects and that the
        variable and its initial value are both placed on 'cpu:0'.
        """
        with tf.Graph().as_default():
            device_fn = variables.VariableDeviceChooser()
            with scopes.arg_scope([variables.global_step], device=device_fn):
                gs = variables.global_step()
                gs2 = variables.global_step()
                # assertEqual replaces the deprecated assertEquals alias,
                # which was removed from unittest in Python 3.12.
                self.assertEqual(gs, gs2)
                self.assertDeviceEqual(gs.device, 'cpu:0')
                self.assertDeviceEqual(gs.initial_value.device, gs.device)
                self.assertDeviceEqual(gs2.device, 'cpu:0')
                self.assertDeviceEqual(gs2.initial_value.device, gs2.device)