def test_composite_clone_float32(self):
    """Check that ``Composite.clone_float32`` leaves no float16 variables behind.

    Two composites are exercised.

    1. A composite with int8/float16/float32 inputs that nests another
       composite and also contains casts and a float16 constant.
    2. A composite built around ``switch`` on a uint8 condition whose
       operands are all float16.

    In each case the original inner graph must contain a float16 variable,
    and the result of ``clone_float32()`` must not.
    """

    def has_f16(comp):
        # True when any variable in the composite's inner fgraph has float16 type.
        return any(v.type == float16 for v in comp.fgraph.variables)

    # Scenario 1: mixed dtypes, a nested composite, casts and a float16 constant.
    w = int8()
    x = float16()
    y = float32()
    cz = Composite([x, y], [tanh(x + cast(y, "float16"))])
    c = Composite(
        [w, x, y],
        [
            cz(x, y)
            - cz(x, y) ** 2
            + cast(x, "int16")
            + cast(x, "float32")
            + cast(w, "float16")
            - constant(np.float16(1.0))
        ],
    )
    assert has_f16(c)
    nc = c.clone_float32()
    assert not has_f16(nc)

    # Scenario 2: switch on a uint8 condition with float16 operands.
    v = uint8()
    w = float16()
    x = float16()
    y = float16()
    z = float16()
    c = Composite([v, w, x, y, z], [switch(v, mul(w, x, y), z)])
    assert has_f16(c)
    nc = c.clone_float32()
    assert not has_f16(nc)
def grad(self, inp, grads):
    """Build the symbolic gradient of this Op with respect to its single input.

    The local derivative used is ``-1 / expm1(-x)``. Mathematically this equals
    d/dx log(1 - exp(x)), which suggests this Op computes log1mexp; confirm
    against the enclosing class.

    Parameters
    ----------
    inp : sequence of length 1
        The Op's single input variable ``x``.
    grads : sequence of length 1
        The gradient flowing in from the Op's single output.

    Returns
    -------
    list
        A one-element list: the upstream gradient times the local derivative.
    """
    [x] = inp
    [output_grad] = grads
    local_deriv = true_div(-1.0, expm1(-x))
    # At x == +0.0, expm1(-0.0) is -0.0, so the division yields +inf.
    # The limit from the valid side is -inf, so every infinity is forced to -inf.
    local_deriv = switch(isinf(local_deriv), -np.inf, local_deriv)
    return [output_grad * local_deriv]