コード例 #1
0
ファイル: tf_test_util.py プロジェクト: GregCT/jax
 def polymorphic_shape_to_tensorspec(poly_shape: str) -> tf.TensorSpec:
   in_spec = masking.parse_spec(poly_shape)
   return tf.TensorSpec(
       tuple(
           int(dim_spec) if dim_spec.is_constant else None
           for dim_spec in in_spec),
       dtype=tf.float32)
コード例 #2
0
 def test_parse_spec(self, spec, ans):
   self.assertEqual(str(parse_spec(spec)), ans)
   self.assertEqual(str(remap_ids(UniqueIds(), parse_spec(spec))), ans)
コード例 #3
0
 def test_shape_parsing(self, spec, ans):
     self.assertEqual(str(parse_spec(spec)), ans)
コード例 #4
0
 def shaped_array(shape):
     if isinstance(shape, str):
         return core.ShapedArray(masking.parse_spec(shape), np.float32)
     else:
         return core.ShapedArray(shape, np.float32)
コード例 #5
0
 def solve_shape_vars(shape_spec: str,
                      shape: Sequence[int]) -> Dict[str, int]:
     shape_polys = masking.parse_spec(shape_spec)
     return jax2tf.jax2tf._solve_shape_vars(
         util.safe_zip(shape_polys, shape))