Example #1
0
 def testByteSizeLoadFnWithScalar(self):
   """Checks byte_size_load_fn placement for a scalar Variable.

   The load function is exercised through the device function rather than
   by handing `u.op` to it directly: a scalar's output Tensor only has an
   unknown shape while the Variable is being constructed, so that is the
   only point at which this case can be observed.
   """
   greedy = device_setter_lib.GreedyLoadBalancingStrategy(
       2, device_setter_lib.byte_size_load_fn)
   setter = device_setter.replica_device_setter(
       cluster=self._cluster_spec, ps_strategy=greedy)
   with ops.device(setter):
     scalar_var = variables.Variable(0)
     # Both the variable and its initializer are expected on ps task 0.
     for placed_on in (scalar_var, scalar_var.initializer):
       self.assertDeviceEqual("/job:ps/task:0", placed_on.device)
Example #2
0
 def testByteSizeLoadFn(self):
   """Checks ps placement under GreedyLoadBalancingStrategy(byte_size_load_fn).

   Creates four zero-filled Variables of shapes [2, 2], [2, 1], [2, 2] and
   [1, 3] under a replica_device_setter with two ps tasks. It asserts which
   ps task each Variable and its initializer land on, and that the sum
   `v + w` is placed on the worker job.
   """
   greedy = device_setter_lib.GreedyLoadBalancingStrategy(
       2, device_setter_lib.byte_size_load_fn)
   setter = device_setter.replica_device_setter(
       cluster=self._cluster_spec, ps_strategy=greedy)
   with ops.device(setter):
     # u and w share a shape yet are expected on different tasks, so the
     # placement depends on earlier variables: keep this creation order.
     shapes = ([2, 2], [2, 1], [2, 2], [1, 3])
     u, v, w, x = [
         variables.Variable(array_ops.zeros(shape)) for shape in shapes]
     a = v + w
     expected = (
         (u, "/job:ps/task:0"),
         (v, "/job:ps/task:1"),
         (w, "/job:ps/task:1"),
         (x, "/job:ps/task:0"),
     )
     for var, task in expected:
       self.assertDeviceEqual(task, var.device)
       self.assertDeviceEqual(task, var.initializer.device)
     self.assertDeviceEqual("/job:worker", a.device)
    def testUniformLoadEqualsRoundRobin(self):
        """Checks that a constant load function yields alternating placement.

        Every op is given a load of 1. The four Variables created below are
        expected to alternate between ps task 0 and ps task 1 in creation
        order, and the sum `v + w` is expected on the worker job.
        """
        # NOTE(review): the sibling tests use `self._cluster_spec`; this one
        # relies on a module-level `_CLUSTER_SPEC` that is not visible here.
        uniform = device_setter_lib.GreedyLoadBalancingStrategy(
            2, lambda unused_op: 1)
        setter = device_setter.replica_device_setter(
            cluster=_CLUSTER_SPEC, ps_strategy=uniform)
        with ops.device(setter):
            shapes = ([2, 2], [2, 1], [2, 2], [1, 3])
            u, v, w, x = [
                variables.Variable(array_ops.zeros(shape)) for shape in shapes]
            a = v + w
            expected = (
                (u, "/job:ps/task:0"),
                (v, "/job:ps/task:1"),
                (w, "/job:ps/task:0"),
                (x, "/job:ps/task:1"),
            )
            for var, task in expected:
                self.assertDeviceEqual(task, var.device)
                self.assertDeviceEqual(task, var.initializer.device)
            self.assertDeviceEqual("/job:worker", a.device)