def test_const_bools(self, tf_dtype, np_dtype):
    """Exercise ``tf_const_ext`` on a boolean Const tensor stored in ``bool_val``.

    Random True/False values are turned into a TF tensor proto, whose dtype
    and tensor_shape are copied into a ``PB`` node. The element payload is
    passed as a Python list in ``bool_val``. The extractor's ``infer``
    callback is then invoked. ``self.compare()`` (defined elsewhere) is
    presumably what checks ``self.res`` / call args against
    ``self.expected`` — TODO confirm.

    Args:
        tf_dtype: TF dtype used to build the tensor proto (supplied by the
            test parametrization, not visible here).
        np_dtype: Expected NumPy dtype of the extracted value.
    """
    # This method used to be named ``test_const_floats`` too. A later method
    # of that name in the same class replaced it, so the boolean case never
    # ran. The distinct name lets both tests be collected.
    shape = [1, 1, 50, 50]
    values = np.random.choice(a=[True, False], size=shape, p=[0.5, 0.5])
    # The proto is only used as a source for the TF-encoded dtype and shape.
    tensor_proto = tf.make_tensor_proto(values=values, dtype=tf_dtype, shape=shape)
    pb = PB({
        "attr": PB({
            "value": PB({
                "tensor": PB({
                    "dtype": tensor_proto.dtype,
                    "tensor_shape": tensor_proto.tensor_shape,
                    "bool_val": values.tolist(),
                })
            })
        })
    })
    self.expected = {
        'data_type': np_dtype,
        # ``np.int`` was removed in NumPy 1.24; use an explicit integer dtype.
        'shape': np.asarray(shape, dtype=np.int64),
        'value': values,
    }
    self.res = tf_const_ext(pb=pb)
    self.res["infer"](None)
    self.call_args = self.infer_mock.call_args
    self.expected_call_args = None
    self.compare()
def test_const_uints(self, tf_dtype, np_dtype):
    """Exercise ``tf_const_ext`` on an unsigned-integer Const tensor.

    How the payload is delivered depends on the dtype:

    * For ``tf.uint16``, the payload is given as a Python list in
      ``int_val``.
    * For every other dtype, it is given as the raw ``tensor_content`` bytes
      produced by ``tf.make_tensor_proto``.

    ``self.compare()`` (defined elsewhere) presumably validates the result
    against ``self.expected`` — TODO confirm.

    Args:
        tf_dtype: TF dtype used to build the tensor proto (from the test
            parametrization, not visible here).
        np_dtype: NumPy integer dtype of the generated values; also the
            expected ``data_type``.
    """
    shape = [1, 1, 200, 50]
    limits = np.iinfo(np_dtype)
    # ``high`` is exclusive in np.random.randint, so limits.max is never drawn.
    values = np.random.randint(low=limits.min, high=limits.max, size=shape, dtype=np_dtype)
    tensor_proto = tf.make_tensor_proto(values=values, dtype=tf_dtype, shape=shape)
    pb = PB({
        "attr": PB({
            "value": PB({
                "tensor": PB({
                    "dtype": tensor_proto.dtype,
                    "tensor_shape": tensor_proto.tensor_shape,
                })
            })
        })
    })
    # The payload field is attached after construction because its name
    # depends on the dtype under test.
    if tf_dtype == tf.uint16:
        pb.attr.value.tensor.int_val = values.tolist()
    else:
        pb.attr.value.tensor.tensor_content = tensor_proto.tensor_content
    self.expected = {
        'data_type': np_dtype,
        # ``np.int`` was removed in NumPy 1.24; use an explicit integer dtype.
        'shape': np.asarray(shape, dtype=np.int64),
        'value': values,
    }
    self.res = tf_const_ext(pb=pb)
    self.res["infer"](None)
    self.call_args = self.infer_mock.call_args
    self.expected_call_args = None
    self.compare()
def test_const_floats(self, tf_dtype, np_dtype):
    """Exercise ``tf_const_ext`` on a floating-point Const tensor.

    Random values are serialized with ``tf.make_tensor_proto``. The resulting
    dtype, tensor_shape and raw ``tensor_content`` bytes are copied into a
    ``PB`` node for the extractor. ``self.compare()`` (defined elsewhere)
    presumably validates the result against ``self.expected`` — TODO
    confirm.

    Args:
        tf_dtype: TF dtype used to build the tensor proto (from the test
            parametrization, not visible here).
        np_dtype: NumPy dtype the sampled values are cast to; also the
            expected ``data_type``.
    """
    shape = [1, 1, 200, 50]
    # Sampling bounds are fixed to the float32 range whatever np_dtype is.
    # NOTE(review): a dtype narrower than float32 (e.g. float16) would
    # overflow to inf on the cast below — confirm which dtypes are used.
    float32_limits = np.finfo(np.float32)
    values = np.random.uniform(
        low=float32_limits.min, high=float32_limits.max, size=shape
    ).astype(np_dtype)
    tensor_proto = tf.make_tensor_proto(values=values, dtype=tf_dtype, shape=shape)
    pb = PB({
        "attr": PB({
            "value": PB({
                "tensor": PB({
                    "dtype": tensor_proto.dtype,
                    "tensor_shape": tensor_proto.tensor_shape,
                    "tensor_content": tensor_proto.tensor_content,
                })
            })
        })
    })
    self.expected = {
        'data_type': np_dtype,
        # ``np.int`` was removed in NumPy 1.24; use an explicit integer dtype.
        'shape': np.asarray(shape, dtype=np.int64),
        'value': values,
    }
    self.res = tf_const_ext(pb=pb)
    self.res["infer"](None)
    self.call_args = self.infer_mock.call_args
    self.expected_call_args = None
    self.compare()