def testEraseFirstGrad(self):
  """Erasing the first-inserted key keeps gradients to both values intact.

  NOTE(review): assumes tensor_map_erase returns only the updated map handle,
  which is how the has_key-based tests in this file use it. The value stored
  under `k` is therefore observed through a lookup taken before the erase,
  not through an erase output.
  """
  with backprop.GradientTape(persistent=True) as tape:
    m = map_ops.empty_tensor_map()
    k = constant_op.constant(1.0)
    k2 = constant_op.constant(2.0)
    v = constant_op.constant(11.0)
    v2 = constant_op.constant(22.0)
    tape.watch(v)
    tape.watch(v2)
    m = map_ops.tensor_map_insert(m, k, v)
    # Read the first value now; this tensor stays connected to `v` after erase.
    l = map_ops.tensor_map_lookup(m, k, v.dtype)
    m = map_ops.tensor_map_insert(m, k2, v2)
    # Erase the first key: only `k2` should remain.
    m = map_ops.tensor_map_erase(m, k, v.dtype)
    self.assertAllEqual(map_ops.tensor_map_size(m), 1)
    self.assertAllEqual(map_ops.tensor_map_has_key(m, k), False)
    l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
    self.assertAllClose(l2, v2)
    self.assertAllClose(l, v)
    g = tape.gradient(l * 5, v)
    self.assertAllEqual(g, 5)
    g2 = tape.gradient(l2 * 6, v2)
    self.assertAllEqual(g2, 6)
    # Erasing the last remaining key empties the map.
    m = map_ops.tensor_map_erase(m, k2, v2.dtype)
    self.assertAllEqual(map_ops.tensor_map_size(m), 0)
    del tape
def testTensorMapEraseFromEmptyMapFails(self):
  """Erasing any key from a freshly created, empty map is rejected."""
  empty_map = map_ops.empty_tensor_map()
  missing_key = constant_op.constant(1.0)
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "Trying to erase non-existent item."):
    # Evaluate explicitly so the failure surfaces even if the op is deferred.
    erased = map_ops.tensor_map_erase(empty_map, missing_key, dtypes.float32)
    self.evaluate(erased)
def testStringKeyGrad(self):
  """String keys support insert, lookup, replace and erase with gradients.

  Gradients are checked through each lookup back to the watched float values.
  """
  with backprop.GradientTape(persistent=True) as tape:
    handle = map_ops.empty_tensor_map()
    first_key = constant_op.constant("key")
    second_key = constant_op.constant("key2")
    first_val = constant_op.constant(2.0)
    second_val = constant_op.constant(22.0)
    tape.watch([first_val, second_val])
    handle = map_ops.tensor_map_insert(handle, first_key, first_val)
    handle = map_ops.tensor_map_insert(handle, second_key, second_val)
    self.assertAllEqual(map_ops.tensor_map_size(handle), 2)

    # Lookup returns the stored value and d(5 * x)/dx == 5.
    fetched = map_ops.tensor_map_lookup(handle, first_key, first_val.dtype)
    self.assertAllClose(fetched, first_val)
    self.assertAllClose(tape.gradient(fetched * 5, first_val), 5)

    # Re-inserting under an existing key replaces its value.
    handle = map_ops.tensor_map_insert(handle, first_key, second_val)
    replaced = map_ops.tensor_map_lookup(handle, first_key, second_val.dtype)
    self.assertAllClose(replaced, second_val)
    self.assertAllEqual(tape.gradient(replaced * 6, second_val), 6)

    # Erase drops the key; the other entry and its gradient survive.
    handle = map_ops.tensor_map_erase(handle, first_key, second_val.dtype)
    self.assertAllEqual(map_ops.tensor_map_size(handle), 1)
    self.assertAllEqual(map_ops.tensor_map_has_key(handle, first_key), False)
    remaining = map_ops.tensor_map_lookup(handle, second_key, second_val.dtype)
    self.assertAllEqual(tape.gradient(remaining * 6, second_val), 6)
    del tape
def testStringKeyReplaceEraseGrad(self):
  """Replacing, then erasing, a string key keeps gradients flowing to `v2`.

  Named distinctly from testStringKeyGrad: two methods with the same name in
  one class means the later definition replaces the earlier, so one never runs.

  NOTE(review): assumes tensor_map_erase returns only the updated map handle,
  as the has_key-based tests in this file use it; the removal is verified via
  size/has_key rather than through an erased-value output.
  """
  with backprop.GradientTape(persistent=True) as tape:
    m = map_ops.empty_tensor_map()
    k = constant_op.constant("key")
    k2 = constant_op.constant("key2")
    v = constant_op.constant(2.0)
    v2 = constant_op.constant(22.0)
    tape.watch(v2)
    m = map_ops.tensor_map_insert(m, k2, v2)
    m = map_ops.tensor_map_insert(m, k, v)
    s = map_ops.tensor_map_size(m)
    self.assertAllEqual(s, 2)
    l = map_ops.tensor_map_lookup(m, k, v.dtype)
    self.assertAllClose(l, v)
    # Re-inserting under `k` replaces `v` with `v2`.
    m = map_ops.tensor_map_insert(m, k, v2)
    l2 = map_ops.tensor_map_lookup(m, k, v2.dtype)
    self.assertAllClose(l2, v2)
    g = tape.gradient(l2 * 5, v2)
    self.assertAllEqual(g, 5)
    # Erase `k`; the `k2` entry remains and still routes gradient to `v2`.
    m = map_ops.tensor_map_erase(m, k, v2.dtype)
    s = map_ops.tensor_map_size(m)
    self.assertAllEqual(s, 1)
    self.assertAllEqual(map_ops.tensor_map_has_key(m, k), False)
    l3 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
    self.assertAllClose(l3, v2)
    g2 = tape.gradient(l3 * 6, v2)
    self.assertAllEqual(g2, 6)
    del tape
def testTensorMapErase(self):
  """Inserting then erasing a single key takes the map size from 1 to 0."""
  key = constant_op.constant(1.0)
  value = constant_op.constant(2.0)
  handle = map_ops.tensor_map_insert(map_ops.empty_tensor_map(), key, value)
  self.assertAllEqual(map_ops.tensor_map_size(handle), 1)
  handle = map_ops.tensor_map_erase(handle, key, value.dtype)
  self.assertAllEqual(map_ops.tensor_map_size(handle), 0)
def testTensorMapEraseMissingKeyFails(self):
  """Erasing a key that is absent from a non-empty map is rejected."""
  m = map_ops.empty_tensor_map()
  k = constant_op.constant(1.0)
  k2 = constant_op.constant(2.0)
  v = constant_op.constant(2.0)
  m = map_ops.tensor_map_insert(m, k2, v)
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "Trying to erase non-existent item."):
    # Bind the erase result as a single map handle, the same way
    # testTensorMapEraseFromEmptyMapFails does. NOTE(review): assumes erase
    # has one output; unpacking it into two names would raise an unrelated
    # error instead of the InvalidArgumentError under test.
    m = map_ops.tensor_map_erase(m, k, dtypes.float32)
    self.evaluate(m)
def testVectorValueErase(self):
  """Vector key and value: insert, look up, then erase back to an empty map.

  Named distinctly from testVectorValue: two methods with the same name in one
  class means the later definition replaces the earlier, so one never runs.

  NOTE(review): assumes tensor_map_erase returns only the updated map handle;
  the stored value is checked through the lookup made before the erase.
  """
  m = map_ops.empty_tensor_map()
  k = constant_op.constant([1.0, 2.0])
  v = constant_op.constant([11.0, 22.0])
  m = map_ops.tensor_map_insert(m, k, v)
  s = map_ops.tensor_map_size(m)
  self.assertAllEqual(s, 1)
  l = map_ops.tensor_map_lookup(m, k, v.dtype)
  self.assertAllEqual(l, v)
  m = map_ops.tensor_map_erase(m, k, v.dtype)
  s = map_ops.tensor_map_size(m)
  self.assertAllEqual(s, 0)
def testStringValue(self):
  """A string-valued entry coexists with a float entry and can be erased.

  NOTE(review): assumes tensor_map_erase returns only the updated map handle,
  as the has_key-based tests in this file use it; the removal is verified via
  size/has_key instead of through an erased-value output.
  """
  m = map_ops.empty_tensor_map()
  k = constant_op.constant("key")
  v = constant_op.constant("value")
  k2 = constant_op.constant(1.0)
  v2 = constant_op.constant(2.0)
  m = map_ops.tensor_map_insert(m, k, v)
  m = map_ops.tensor_map_insert(m, k2, v2)
  l = map_ops.tensor_map_lookup(m, k, v.dtype)
  self.assertAllEqual(l, v)
  l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
  self.assertAllClose(l2, v2)
  m = map_ops.tensor_map_erase(m, k, v.dtype)
  # Only the string entry is removed; the float entry is untouched.
  self.assertAllEqual(map_ops.tensor_map_size(m), 1)
  self.assertAllEqual(map_ops.tensor_map_has_key(m, k), False)
  self.assertAllEqual(map_ops.tensor_map_has_key(m, k2), True)
def testVectorValue(self):
  """Rank-1 key and value: insert, look up, erase, then confirm the key is gone."""
  vec_key = constant_op.constant([1.0, 2.0])
  vec_val = constant_op.constant([11.0, 22.0])
  # Insert and lookup.
  handle = map_ops.tensor_map_insert(
      map_ops.empty_tensor_map(), vec_key, vec_val)
  self.assertAllEqual(map_ops.tensor_map_size(handle), 1)
  self.assertAllEqual(
      map_ops.tensor_map_lookup(handle, vec_key, vec_val.dtype), vec_val)
  # Erase and has_key.
  handle = map_ops.tensor_map_erase(handle, vec_key, vec_val.dtype)
  self.assertAllEqual(map_ops.tensor_map_size(handle), 0)
  self.assertAllEqual(map_ops.tensor_map_has_key(handle, vec_key), False)
def testEraseInsertComposedGrad(self):
  """Gradient survives lookup -> erase -> re-insert under a new key -> lookup."""
  with backprop.GradientTape(persistent=True) as tape:
    handle = map_ops.empty_tensor_map()
    key_a = constant_op.constant(1.0)
    key_b = constant_op.constant(2.0)
    val_a = constant_op.constant(11.0)
    val_b = constant_op.constant(22.0)
    tape.watch([val_a, val_b])
    handle = map_ops.tensor_map_insert(handle, key_a, val_a)
    fetched = map_ops.tensor_map_lookup(handle, key_a, val_a.dtype)
    handle = map_ops.tensor_map_erase(handle, key_a, val_a.dtype)
    # Move the looked-up tensor to a different key; it should still trace to val_a.
    handle = map_ops.tensor_map_insert(handle, key_b, fetched)
    refetched = map_ops.tensor_map_lookup(handle, key_b, fetched.dtype)
    self.assertAllEqual(tape.gradient(refetched * 5, val_a), 5)
    del tape
def testStringKeyValue(self):
  """A string->string entry and a float->float entry share one map.

  Erasing the string key removes only that entry.
  """
  handle = map_ops.empty_tensor_map()
  str_key = constant_op.constant("key")
  str_val = constant_op.constant("value")
  num_key = constant_op.constant(1.0)
  num_val = constant_op.constant(2.0)
  for key, val in ((str_key, str_val), (num_key, num_val)):
    handle = map_ops.tensor_map_insert(handle, key, val)
  # String value reads back exactly; float value within tolerance.
  self.assertAllEqual(
      map_ops.tensor_map_lookup(handle, str_key, str_val.dtype), str_val)
  self.assertAllClose(
      map_ops.tensor_map_lookup(handle, num_key, num_val.dtype), num_val)
  # has_key flips for the erased key only.
  self.assertAllEqual(map_ops.tensor_map_has_key(handle, str_key), True)
  handle = map_ops.tensor_map_erase(handle, str_key, str_val.dtype)
  self.assertAllEqual(map_ops.tensor_map_has_key(handle, str_key), False)
  self.assertAllEqual(map_ops.tensor_map_has_key(handle, num_key), True)