示例#1
0
 def test_validating_dataset_input_tensors_with_shape_mismatch(
         self, distribution):
     """Check that per-replica inputs with differing shapes are rejected.

     The first distributed input wraps a (1, 2) tensor and a (2, 2) tensor,
     so validation is expected to raise a ValueError about mismatched
     input tensor shapes.
     """
     # Device names and concrete shapes are left out of the pattern because
     # the order in which they appear in the message is not deterministic
     # across runs.
     expected_error = ('Input tensor shapes do not match for '
                       'distributed tensor inputs '
                       'DistributedValues:.+')
     with self.cached_session():
         row = constant_op.constant([1, 2], shape=(1, 2))
         square = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2))
         mismatched = values.DistributedValues((row, square))
         matched = values.DistributedValues((row, row))
         with self.assertRaisesRegex(ValueError, expected_error):
             with distribution.scope():
                 distributed_training_utils.validate_distributed_dataset_inputs(
                     distribution, mismatched, matched)
 def test_validating_dataset_input_tensors_with_dtype_mismatch(
     self, distribution):
   """Check that per-replica inputs with differing dtypes are rejected.

   `x` pairs an int32 tensor with a float64 tensor of the same shape, so
   validation is expected to raise a ValueError about mismatched input
   tensor dtypes.
   """
   with self.cached_session():
     a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
     b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64)
     device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0'))
     x = values.DistributedValues(device_map, (a, b))
     y = values.DistributedValues(device_map, (a, a))
     # Removed device and input tensor dtype details from the error message
     # since the order of the device and the corresponding input tensor dtype
     # is not deterministic over different runs.
     # Use assertRaisesRegex: the `assertRaisesRegexp` spelling is a
     # deprecated unittest alias that newer Python versions no longer provide.
     with self.assertRaisesRegex(
         ValueError, 'Input tensor dtypes do not match for '
         'distributed tensor inputs '
         'DistributedValues:.+'):
       with distribution.scope():
         distributed_training_utils.validate_distributed_dataset_inputs(
             distribution, x, y)
示例#3
0
 def testGetEager(self):
   """Check DistributedValues.get() for a known, default and unknown device."""
   with ops.device("/device:CPU:0"):
     one = constant_op.constant(1)
     two = constant_op.constant(2)
     v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two})
     self.assertEqual(two, v.get("/device:GPU:0"))
     # With no argument the CPU:0 entry is expected back (presumably resolved
     # from the enclosing device scope).
     self.assertEqual(one, v.get())
     # A device missing from the mapping must raise. The call is made
     # directly: wrapping it in another assertion would be dead code, because
     # the ValueError propagates before any outer assertion can run.
     with self.assertRaises(ValueError):
       v.get("/device:GPU:2")
示例#4
0
 def testGetEager(self):
   """Check DistributedValues.get() for a known, default and unknown device."""
   with ops.device("/device:CPU:0"):
     one = constant_op.constant(1)
     two = constant_op.constant(2)
     device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0"))
     v = values.DistributedValues(device_map, (one, two))
     self.assertEqual(two, v.get("/device:GPU:0"))
     # With no argument the CPU:0 entry is expected back (presumably resolved
     # from the enclosing device scope).
     self.assertEqual(one, v.get())
     # A device missing from the device map must raise. The call is made
     # directly: wrapping it in another assertion would be dead code, because
     # the ValueError propagates before any outer assertion can run.
     with self.assertRaises(ValueError):
       v.get("/device:GPU:2")
示例#5
0
 def testIsTensorLikeWithAConstant(self):
   """Check that a tensor mixed with a Python float is not tensor-like.

   One component is a constant tensor and the other is the plain float 2.0;
   the resulting DistributedValues is expected to report is_tensor_like as
   False and to be rejected by tensor_util.is_tensor.
   """
   cpu, gpu = "/device:CPU:0", "/device:GPU:0"
   with context.graph_mode():
     with ops.Graph().as_default(), ops.device(cpu):
       tensor_component = constant_op.constant(1)
       float_component = 2.0
       dist = values.DistributedValues(
           {cpu: tensor_component, gpu: float_component})
       self.assertEqual(float_component, dist.get(gpu))
       self.assertEqual(tensor_component, dist.get())
       self.assertFalse(dist.is_tensor_like)
       self.assertFalse(tensor_util.is_tensor(dist))
示例#6
0
 def testGetGraph(self):
   """Check DistributedValues.get() in graph mode for known/unknown devices."""
   with context.graph_mode(), \
       ops.Graph().as_default(), \
       ops.device("/device:CPU:0"):
     one = constant_op.constant(1)
     two = constant_op.constant(2)
     v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two})
     self.assertEqual(two, v.get("/device:GPU:0"))
     # With no argument the CPU:0 entry is expected back (presumably resolved
     # from the enclosing device scope).
     self.assertEqual(one, v.get())
     # A device missing from the mapping must raise. The call is made
     # directly: wrapping it in another assertion would be dead code, because
     # the ValueError propagates before any outer assertion can run.
     with self.assertRaises(ValueError):
       v.get("/device:GPU:2")
示例#7
0
 def testIsTensorLike(self):
   """Check that DistributedValues built only from tensors is tensor-like.

   Both components are constant tensors, so is_tensor_like and
   tensor_util.is_tensor are both expected to be True.
   """
   cpu, gpu = "/device:CPU:0", "/device:GPU:0"
   with context.graph_mode():
     with ops.Graph().as_default(), ops.device(cpu):
       first = constant_op.constant(1)
       second = constant_op.constant(2)
       replica_devices = values.ReplicaDeviceMap((cpu, gpu))
       dist = values.DistributedValues(replica_devices, (first, second))
       self.assertEqual(second, dist.get(gpu))
       self.assertEqual(first, dist.get())
       self.assertTrue(dist.is_tensor_like)
       self.assertTrue(tensor_util.is_tensor(dist))
示例#8
0
  def testNonMatchingVariableCreation(self, distribution):
    """Expect a RuntimeError when replicas are given different variable names.

    The test is currently skipped unconditionally (tracked as b/123075960);
    the skip call aborts the test, so the statements below it do not run
    until the skip is removed.
    """
    self.skipTest("b/123075960")

    def model_fn(name):
      # Create a variable under the per-replica name, then hit a merge_call
      # with an identity function before returning the variable.
      created = variable_scope.variable(1.0, name=name)
      ds_context.get_replica_context().merge_call(lambda arg: arg)
      return created

    with distribution.scope():
      per_replica_names = values.DistributedValues(("foo", "bar"))
      with self.assertRaises(RuntimeError):
        distribution.extended.call_for_each_replica(
            model_fn, args=(per_replica_names,))