def _divide_to_multiply_block(block):
    """Rewrite ``real_div(x, const_y)`` as ``mul(x, 1 / const_y)`` inside ``block``.

    Nested blocks owned by each op are processed recursively first; an op
    that owns nested blocks is never rewritten itself.

    A ``real_div`` op is rewritten only when:
      * ``y`` has a compile-time value (``op.y.val is not None``),
      * both ``x`` and ``y`` have floating-point dtypes, and
      * the computed reciprocal ``1 / y`` is finite everywhere.

    The new ``mul`` op is inserted before the original op, named
    ``"_inversed_" + op.name``, takes over all later uses of the original
    output, and the original op is removed from ``block``.

    :param block: the MIL block to transform in place.
    """
    # Iterate over a copy because ops are removed from the block below.
    for op in list(block.operations):
        for b in op.blocks:
            _divide_to_multiply_block(b)
        if len(op.blocks) > 0:
            # Ops that own nested blocks are never candidates for this rewrite.
            continue

        # If real_div has integer input, the result is an integer (following
        # TensorFlow spec). This pass turns y into a floating-point reciprocal,
        # so it must only run when BOTH operands are already floating point:
        # if x or y was originally an integer and y becomes a floating point
        # number, the original type signature would not be preserved.
        # For an integer y in particular, np.array(1.0, dtype=<int>) truncates
        # to 1 and 1 / y yields a NumPy float (typically float64) regardless
        # of x's precision, so y's dtype is checked as well as x's.
        if (
            op.op_type == "real_div"
            and op.y.val is not None
            and _types.is_float(op.x.dtype)
            and _types.is_float(op.y.dtype)
        ):
            with block:
                # Reciprocal computed in y's own float dtype.
                new_y_val = np.array(1.0, dtype=op.y.val.dtype) / op.y.val
                if not np.isfinite(new_y_val).all():
                    # e.g. y contains 0: 1 / y would be inf/nan, so keep the division.
                    continue

                x = mb.mul(
                    x=op.x, y=new_y_val, name="_inversed_" + op.name, before_op=op
                )
                op.enclosing_block.replace_uses_of_var_after_op(
                    anchor_op=op, old_var=op.outputs[0], new_var=x
                )
                block.remove_ops([op])
def _tensor_field_by_type(tensor_val, builtin_type):
    """Return the repeated ``values`` container of ``tensor_val`` for ``builtin_type``.

    The field is chosen by checking, in order:
      * bool                   -> ``bools``
      * int64 / uint64         -> ``longInts``
      * int8 / uint8 / uint32  -> ``bytes``
      * any other int type     -> ``ints``
      * fp64 -> ``doubles``; fp32 -> ``floats``; fp16 -> ``bytes``
      * str                    -> ``strings``

    :param tensor_val: tensor value message exposing the fields above.
    :param builtin_type: MIL builtin type describing the element type.
    :raises TypeError: for a float type other than fp64 / fp32 / fp16.
    :raises NotImplementedError: for any other unsupported type.
    """
    if builtin_type == types.bool:
        return tensor_val.bools.values

    if types.is_int(builtin_type):
        if builtin_type == types.int64 or builtin_type == types.uint64:
            return tensor_val.longInts.values
        if builtin_type in (types.int8, types.uint8, types.uint32):
            return tensor_val.bytes.values
        return tensor_val.ints.values

    if types.is_float(builtin_type):
        # Float precisions and the field that stores each one, tried in order.
        float_fields = (
            (types.fp64, "doubles"),
            (types.fp32, "floats"),
            (types.fp16, "bytes"),
        )
        for float_type, field_name in float_fields:
            if builtin_type == float_type:
                return getattr(tensor_val, field_name).values
        raise TypeError(
            "Unsupported float dtype for MIL proto serialization: {}".format(
                builtin_to_string(builtin_type)
            )
        )

    if builtin_type == types.str:
        return tensor_val.strings.values

    raise NotImplementedError("Unimplemented tensor type for: " + str(builtin_type))
def test_builder_real_div_both_ints(self):
    """real_div on two int32 constants must produce a floating-point result.

    Verifies the computed value (5 / 2 == 2.5), the type of an output
    element, the builtin dtype of the output var, and that the var's
    symbolic type agrees with its symbolic value's primitive type.
    """
    numerator = np.array([5], dtype=np.int32)
    denominator = np.array([2], dtype=np.int32)
    want = np.array([2.5], dtype=np.float32)

    quotient = mb.real_div(x=numerator, y=denominator)
    np.testing.assert_allclose(want, quotient.val, atol=1e-04, rtol=1e-05)

    # Integer operands must not truncate: each element is a float.
    assert isinstance(quotient.val[0], (float, np.float32))
    # The output var's dtype is a floating-point builtin type.
    assert types.is_float(quotient.dtype)
    # Declared symbolic type and the value's own type must agree.
    assert quotient._sym_type.get_primitive() == quotient._sym_val.get_primitive()