Ejemplo n.º 1
0
 def test_cast(self):
     """Check that a Cast node yields the dtype paired with each target type.

     On legacy setups (``legacy_onnx_pre_1_2()`` or ``legacy_opset_pre_6()``
     is true) the ``to`` attribute is given as the type's name string;
     otherwise the ``TensorProto`` member of the same name is used.
     """
     use_type_names = legacy_onnx_pre_1_2() or legacy_opset_pre_6()
     expected_dtypes = [
         ("FLOAT", tf.float32),
         ("UINT8", tf.uint8),
         ("INT8", tf.int8),
         ("UINT16", tf.uint16),
         ("INT16", tf.int16),
         ("INT32", tf.int32),
         ("INT64", tf.int64),
         ("BOOL", tf.bool),
         ("FLOAT16", tf.float16),
         ("DOUBLE", tf.float64),
         ("COMPLEX64", tf.complex64),
         ("COMPLEX128", tf.complex128),
     ]
     for type_name, tf_type in expected_dtypes:
         if use_type_names:
             target = type_name
         else:
             # Non-legacy setups take the TensorProto member of that name.
             target = getattr(TensorProto, type_name)
         node_def = helper.make_node("Cast", ["input"], ["output"], to=target)
         # The same small integer vector is fed to every case.
         output = run_node(node_def, [[2, 3]])
         np.testing.assert_equal(output["output"].dtype, tf_type)
Ejemplo n.º 2
0
 def test_batch_normalization(self):
     """Compare a BatchNormalization node against the reference helper.

     Skipped when ``legacy_opset_pre_6()`` is true. The node is run with
     flat ``[5]``-shaped scale, bias, mean and variance, while the
     reference ``self._batch_normalization`` receives the same arrays
     reshaped to ``[1, 5, 1, 1]``. Results must agree to 5 decimals.
     """
     if legacy_opset_pre_6():
         raise unittest.SkipTest("Backend doesn't support consumed flag")

     eps = 0.001  # shared by the node attribute and the reference call
     channel_shape = [5]
     broadcast_shape = [1, 5, 1, 1]

     node_def = helper.make_node(
         "BatchNormalization",
         ["X", "scale", "bias", "mean", "var"],
         ["Y"],
         epsilon=eps)

     # Draw order is kept as x, mean, var, scale, bias.
     x = self._get_rnd([3, 5, 4, 2], 0, 1)
     mean, var, scale, bias = [
         self._get_rnd(channel_shape, 0, 1) for _ in range(4)
     ]

     # Argument order of the reference helper: x, mean, var, bias, scale.
     golden = self._batch_normalization(
         x,
         mean.reshape(broadcast_shape),
         var.reshape(broadcast_shape),
         bias.reshape(broadcast_shape),
         scale.reshape(broadcast_shape),
         eps)
     output = run_node(node_def, [x, scale, bias, mean, var])
     np.testing.assert_almost_equal(output["Y"], golden, decimal=5)
Ejemplo n.º 3
0
("test_sigmoid", tf.sigmoid, "Sigmoid", [get_rnd([10, 10])], {}),
("test_slice", tf.slice, "Slice", [get_rnd([5, 6, 7])], {"begin": [1, 0, 0], "size": [1, 1, 3]}),
("test_softmax", tf.nn.softmax, "Softmax", [get_rnd([10, 10])], {}),
("test_softplus", tf.nn.softplus, "Softplus", [get_rnd([10, 10])], {}),
("test_softsign", tf.nn.softsign, "Softsign", [get_rnd([10, 10])], {}),
("test_space_to_depth", tf.space_to_depth, "SpaceToDepth", [get_rnd([2, 8, 8, 5])], {"block_size": 2}),
("test_split", tf.split, "split", [get_rnd([10, 10]), [2, 3, 5]], {}),
("test_sqrt", tf.sqrt, "Sqrt", [get_rnd([10, 10])], {}),
("test_squeeze", tf.squeeze, "Squeeze", [get_rnd([1, 1, 10, 10])], {"axis":[0, 1]}),
("test_subtract", tf.subtract, "Sub", [get_rnd([10, 10]), get_rnd([10, 10])], {}),
("test_tanh", tf.tanh, "Tanh", [get_rnd([10, 10])], {}),
("test_top_k", tf.nn.top_k, "TopKV2", [get_rnd([10, 10, 10, 10])], {"k": 3}),
# Use reverse to test ignore_unimplemented
("test_unimplemented", tf.reverse, "ReverseV2", [get_rnd([1, 2, 3, 4]), [3]], {}, {"ignore_unimplemented": True}),
("test_unpack", tf.unstack, "unstack", [get_rnd([2, 3, 4])], {}),
("test_xor", tf.logical_xor, "LogicalXor", [get_rnd([10, 10], dtype=np.bool_), get_rnd([10, 10], dtype=np.bool_)], {}),
("test_transpose", tf.transpose, "transpose", [get_rnd([2, 10])], {"perm":[1, 0]}),
("test_concat", tf.concat, "concat", [[get_rnd([1, 10]),get_rnd([10, 10]),get_rnd([20, 10])], 0], {})
]

# The Tile case is only added when legacy_opset_pre_6() reports false.
if not legacy_opset_pre_6():
    test_cases.append((
        "test_tile",
        tf.tile,
        "Tile",
        [get_rnd([1, 2, 3, 4]),
         np.random.randint(1, 10, (4,), dtype=np.int32)],
        {},
    ))

# Build one test method per case and attach it to TestNode under the name
# stored in the case's first element. The index is not needed, so the list
# is iterated directly instead of through enumerate().
for val in test_cases:
    test_method = create_test(val)
    test_method.__name__ = str(val[0])
    setattr(TestNode, test_method.__name__, test_method)

if __name__ == '__main__':
    unittest.main()