def testByteSizeLoadFnWithScalar(self):
    """Check placement of a scalar variable under `byte_size_load_fn`.

    Builds a replica device setter from the test's cluster spec with a
    `GreedyLoadBalancingStrategy(2, byte_size_load_fn)` and asserts that a
    scalar `Variable(0)` and its initializer are both placed on
    "/job:ps/task:0".
    """
    strategy = device_setter_lib.GreedyLoadBalancingStrategy(
        2, device_setter_lib.byte_size_load_fn)
    device_fn = device_setter.replica_device_setter(
        cluster=self._cluster_spec, ps_strategy=strategy)

    with ops.device(device_fn):
        # Note: we must test the load function as part of the device function
        # instead of passing u.op to the function directly, because the only
        # time that the output Tensor has unknown shape for scalars is during
        # Variable construction.
        scalar_var = variables.Variable(0)

        # The variable and its initializer are expected on the same ps task.
        for placement in (scalar_var.device, scalar_var.initializer.device):
            self.assertDeviceEqual("/job:ps/task:0", placement)
def testByteSizeLoadFn(self):
    """Check variable placement when load is measured by `byte_size_load_fn`.

    Four variables of different shapes are created under a replica device
    setter that uses `GreedyLoadBalancingStrategy(2, byte_size_load_fn)`.
    The test asserts the ps task of each variable and its initializer, and
    that a non-variable op (`v + w`) lands on "/job:worker".
    """
    strategy = device_setter_lib.GreedyLoadBalancingStrategy(
        2, device_setter_lib.byte_size_load_fn)
    device_fn = device_setter.replica_device_setter(
        cluster=self._cluster_spec, ps_strategy=strategy)

    with ops.device(device_fn):
        # Creation order matters: each placement depends on what was placed
        # before it.
        var_2x2_a = variables.Variable(array_ops.zeros([2, 2]))
        var_2x1 = variables.Variable(array_ops.zeros([2, 1]))
        var_2x2_b = variables.Variable(array_ops.zeros([2, 2]))
        var_1x3 = variables.Variable(array_ops.zeros([1, 3]))
        total = var_2x1 + var_2x2_b

        # NOTE(review): the expected tasks look consistent with each variable
        # going to the task with the smaller accumulated byte size so far;
        # confirm against GreedyLoadBalancingStrategy.
        expected_placements = (
            ("/job:ps/task:0", var_2x2_a),
            ("/job:ps/task:1", var_2x1),
            ("/job:ps/task:1", var_2x2_b),
            ("/job:ps/task:0", var_1x3),
        )
        for ps_device, var in expected_placements:
            self.assertDeviceEqual(ps_device, var.device)
            self.assertDeviceEqual(ps_device, var.initializer.device)

        self.assertDeviceEqual("/job:worker", total.device)
def testUniformLoadEqualsRoundRobin(self):
    """Check that a constant load function alternates between ps tasks.

    Every op is given a load of 1, and four variables are created under a
    replica device setter that uses `GreedyLoadBalancingStrategy(2, _load_fn)`.
    The test asserts that the variables (and their initializers) alternate
    between "/job:ps/task:0" and "/job:ps/task:1" in creation order, and that
    a non-variable op (`v + w`) lands on "/job:worker".
    """

    def _load_fn(unused_op):
        """Report the same load for every op."""
        return 1

    # Use the fixture's cluster spec, as the sibling byte-size tests do,
    # rather than a separate module-level constant.
    with ops.device(
            device_setter.replica_device_setter(
                cluster=self._cluster_spec,
                ps_strategy=device_setter_lib.GreedyLoadBalancingStrategy(
                    2, _load_fn))):
        u = variables.Variable(array_ops.zeros([2, 2]))
        v = variables.Variable(array_ops.zeros([2, 1]))
        w = variables.Variable(array_ops.zeros([2, 2]))
        x = variables.Variable(array_ops.zeros([1, 3]))
        a = v + w

        # With uniform load the shapes are irrelevant; only creation order
        # determines the expected task.
        expected_placements = (
            ("/job:ps/task:0", u),
            ("/job:ps/task:1", v),
            ("/job:ps/task:0", w),
            ("/job:ps/task:1", x),
        )
        for ps_device, var in expected_placements:
            self.assertDeviceEqual(ps_device, var.device)
            self.assertDeviceEqual(ps_device, var.initializer.device)

        self.assertDeviceEqual("/job:worker", a.device)