Esempio n. 1
0
def test_cases_with_decomposition():
    """Return a named-parameters decorator that toggles `use_decomposition`.

    Returns:
      An absl `parameterized.named_parameters` decorator with two cases:
      'batch1' (batch_size=1, no decomposition) and 'batch4'
      (batch_size=4, with decomposition).
    """
    cases = [
        dict(testcase_name='batch1', batch_size=1, use_decomposition=False),
        dict(testcase_name='batch4', batch_size=4, use_decomposition=True),
    ]
    return parameterized.named_parameters(*cases)
Esempio n. 2
0
def test_cases():
    """Return named parameters over two (batch_size, context_dim) pairs."""
    small = dict(testcase_name='_batch1_contextdim10', batch_size=1,
                 context_dim=10)
    large = dict(testcase_name='_batch4_contextdim5', batch_size=4,
                 context_dim=5)
    return parameterized.named_parameters(small, large)
def test_cases():
    """Return named parameters over batch_size and actions_from_reward_layer.

    NOTE(review): the case names mention "numtrainsteps", but no
    num_train_steps argument is passed; the second varying argument is
    actually `actions_from_reward_layer`. Consider renaming (this changes
    test IDs).
    """
    # (testcase_name, batch_size, actions_from_reward_layer)
    rows = (
        ('_batch1_numtrainsteps0', 1, False),
        ('_batch4_numtrainsteps10', 4, True),
    )
    return parameterized.named_parameters(*[
        {
            'testcase_name': name,
            'batch_size': batch,
            'actions_from_reward_layer': from_reward,
        } for name, batch, from_reward in rows
    ])
Esempio n. 4
0
def genNamedParametersNArgs(n, rng):
    """Named parameters for every length-`n` shape/dtype combination.

    Args:
      n: Number of arguments; shapes and dtypes are each drawn as
        `CombosWithReplacement(..., n)`.
      rng: Value forwarded unchanged as the `rng` argument of every case.

    Returns:
      A `parameterized.named_parameters` decorator over the cases kept by
      `jtu.cases_from_list`.
    """
    def _case(shapes, dtypes):
        suffix = jtu.format_test_name_suffix("", shapes, dtypes)
        return {
            "testcase_name": suffix,
            "rng": rng,
            "shapes": shapes,
            "dtypes": dtypes,
        }

    all_cases = (_case(shapes, dtypes)
                 for shapes in CombosWithReplacement(all_shapes, n)
                 for dtypes in CombosWithReplacement(float_dtypes, n))
    return parameterized.named_parameters(jtu.cases_from_list(all_cases))
Esempio n. 5
0
def check_default_schedules(cond, fun):
    """Parameterize `fun` over the default schedules accepted by `cond`.

    Args:
      cond: Predicate called once per candidate schedule (a list of
        `(str, int-or-None)` pairs); rejected schedules are dropped.
      fun: The test function to decorate.

    Returns:
      `fun` decorated with one case per kept schedule, named
      `'_' + <label>` and receiving the schedule as `schedule`.
    """
    candidates = (
        ('seq', [('sequential', None)]),
        ('vec', [('vectorized', None)]),
        ('par', [('parallel', None)]),
        ('lim_vmap', [('sequential', None), ('vectorized', 2)]),
        ('soft_pmap', [('parallel', 2), ('vectorized', None)]),
    )
    # Filter eagerly so `cond` sees every candidate before decoration.
    kept = [(label, sched) for label, sched in candidates if cond(sched)]
    decorate = parameterized.named_parameters(
        {"testcase_name": "_" + label, "schedule": sched}
        for label, sched in kept)
    return decorate(fun)
Esempio n. 6
0
def genNamedParametersNArgs(n):
    """Named parameters for every length-`n` shape/floating-dtype combination.

    Args:
      n: Number of arguments; shapes and dtypes are each drawn with
        `itertools.combinations_with_replacement(..., n)`.

    Returns:
      A `parameterized.named_parameters` decorator over the cases kept by
      `jtu.cases_from_list`.
    """
    def _named_case(shapes, dtypes):
        return {
            "testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
            "shapes": shapes,
            "dtypes": dtypes,
        }

    combos = itertools.combinations_with_replacement
    cases = (_named_case(shapes, dtypes)
             for shapes in combos(all_shapes, n)
             for dtypes in combos(jtu.dtypes.floating, n))
    return parameterized.named_parameters(jtu.cases_from_list(cases))
Esempio n. 7
0
 def decorator(fn, *named_executors):
   """Construct a custom `parameterized.named_parameter` decorator for `fn`.

   Args:
     fn: The test function to decorate.
     *named_executors: `(name, executor)` pairs. When none are given, a
       'reference' ReferenceExecutor and a 'local' executor are used.

   Returns:
     `fn` wrapped by `executor_decorator`, then parameterized over the
     executors.
   """
   executors = named_executors or [
       ('reference', reference_executor.ReferenceExecutor(compiler=None)),
       ('local', executor_stacks.local_executor_factory()),
   ]
   apply_params = parameterized.named_parameters(*executors)
   return apply_params(executor_decorator(fn))
def test_cases():
    """Return named parameters over observation/action shapes, batch and seed."""
    # (testcase_name, observation_shape, action_shape, batch_size, seed)
    rows = (
        ('_observation_[5]_action_[3]_batch_1', [5], [3], 1, 12345),
        ('_observation_[3]_action_[5]_batch_2', [3], [5], 2, 98765),
    )
    return parameterized.named_parameters(*[
        {
            'testcase_name': name,
            'observation_shape': obs_shape,
            'action_shape': act_shape,
            'batch_size': batch,
            'seed': seed,
        } for name, obs_shape, act_shape, batch, seed in rows
    ])
Esempio n. 9
0
def test_cases():
    """Return named parameters over batch_size, context_dim and dtype."""
    # (testcase_name, batch_size, context_dim, dtype)
    rows = (
        ('_batch1_contextdim10_float32', 1, 10, tf.float32),
        ('_batch4_contextdim5_float64', 4, 5, tf.float64),
    )
    return parameterized.named_parameters(*[
        dict(testcase_name=name, batch_size=batch, context_dim=dim,
             dtype=dtype)
        for name, batch, dim, dtype in rows
    ])
Esempio n. 10
0
def test_cases():
    """Return named parameters over batch_size, context_dim and num_agents."""
    # (testcase_name, batch_size, context_dim, num_agents)
    rows = (
        ('_batch1_contextdim10_numagents2', 1, 10, 2),
        ('_batch4_contextdim5_numagents10', 4, 5, 10),
    )
    return parameterized.named_parameters(*[
        dict(testcase_name=name, batch_size=batch, context_dim=dim,
             num_agents=agents)
        for name, batch, dim, agents in rows
    ])
Esempio n. 11
0
def test_cases():
    """Return named cases at batch sizes 1 and 4, both with the optimistic strategy."""
    optimistic = linear_policy.ExplorationStrategy.optimistic
    return parameterized.named_parameters(
        dict(testcase_name='batch1UCB', batch_size=1,
             exploration_strategy=optimistic),
        dict(testcase_name='batch4UCB', batch_size=4,
             exploration_strategy=optimistic),
    )
Esempio n. 12
0
def _VersionTuple(version):
    """Return the leading numeric components of `version` as a tuple of ints.

    Parsing stops at the first dot-separated piece that does not start with an
    ASCII digit, and only the leading digits of each piece are kept, e.g.
    '2.10.0' -> (2, 10, 0) and '2.3.0rc1' -> (2, 3, 0).
    """
    numbers = []
    for piece in version.split('.'):
        digits = ''
        for ch in piece:
            if not '0' <= ch <= '9':
                break
            digits += ch
        if not digits:
            break
        numbers.append(int(digits))
    return tuple(numbers)


def RecurrentTestParameters(test_fn):
    """Wrap `test_fn` as a single named test case for the active defun mode.

    The case is named '_function' when `py_utils._UseTfFunction()` is true and
    '_defun' otherwise. With tf.function enabled on TensorFlow older than
    2.3.0, the wrapped test returns early without running `test_fn`.

    Args:
      test_fn: A test method taking only `self`.

    Returns:
      The parameterized wrapper produced by `parameterized.named_parameters`.
    """
    use_tf_function = py_utils._UseTfFunction()

    def WrappedTestFn(self):
        # TODO(laigd): remove this check when 312743821 and 313682500 are in the
        # release.
        # Compare numeric components. A plain string comparison would treat
        # e.g. '2.10.0' as older than '2.3.0' and silently skip the test.
        if use_tf_function and (
                _VersionTuple(tf.compat.v1.__version__) < (2, 3, 0)):
            return
        test_fn(self)

    decorator = parameterized.named_parameters((
        '_function', ) if use_tf_function else ('_defun', ))
    return decorator(WrappedTestFn)
Esempio n. 13
0
def test_cases():
  """Return named parameters over batch, context_dim, policy, dtype and decomp.

  The last case additionally passes `use_eigendecomp=True`; the others do not
  pass that argument at all.
  """
  ucb = linear_agent.ExplorationPolicy.linear_ucb_policy
  thompson = linear_agent.ExplorationPolicy.linear_thompson_sampling_policy
  return parameterized.named_parameters(
      dict(testcase_name='_batch1_contextdim10_float32', batch_size=1,
           context_dim=10, exploration_policy=ucb, dtype=tf.float32),
      dict(testcase_name='_batch4_contextdim5_float64_UCB', batch_size=4,
           context_dim=5, exploration_policy=ucb, dtype=tf.float64),
      dict(testcase_name='_batch4_contextdim5_float64_TS', batch_size=4,
           context_dim=5, exploration_policy=thompson, dtype=tf.float64),
      dict(testcase_name='_batch4_contextdim5_float64_decomp', batch_size=4,
           context_dim=5, exploration_policy=ucb, dtype=tf.float64,
           use_eigendecomp=True),
  )
Esempio n. 14
0
def cg_test_cases():
    """Return named parameters over (`n`, `rhs`) pairs for the `cg` tests."""
    # (n, rhs); the case name is derived from n alone.
    sizes = ((1, 1), (10, 1), (100, 5))
    names = ('_n_1', '_n_10', '_n_100')
    return parameterized.named_parameters(*[
        {'testcase_name': name, 'n': n, 'rhs': rhs}
        for name, (n, rhs) in zip(names, sizes)
    ])
Esempio n. 15
0
 def parameterize_by_sampler(extra, f, subset):
   """Parameterize `f` over JAX random samplers crossed with `extra` variants.

   Args:
     extra: Iterable of `(name_suffix, kwargs)` pairs, or None for a single
       variant with an empty suffix and no extra kwargs.
     f: Test function to decorate.
     subset: If true, the cases are passed through `jtu.cases_from_list`
       (defined elsewhere; presumably selects a subset — confirm).

   Returns:
     `f` decorated with one named case per (sampler, variant) pair.
   """
   variants = [("", {})] if extra is None else list(extra)
   select = jtu.cases_from_list if subset else (lambda cases: cases)
   samplers = [
       ("Uniform", jax.random.uniform),
       ("Normal", jax.random.normal),
       ("Bernoulli", partial(jax.random.bernoulli, p=0.5)),
       ("TruncatedNormal",
        partial(jax.random.truncated_normal, lower=-2, upper=2)),
   ]
   cases = ({"testcase_name": sampler_name + suffix,
             "distr_sample": sampler,
             **kwargs}
            for sampler_name, sampler in samplers
            for suffix, kwargs in variants)
   return parameterized.named_parameters(select(cases))(f)
Esempio n. 16
0
def _enhance_named_parameters(factories, testcases):
  """Return `parameterized.named_parameters` over testcases x variant factories.

  Args:
    factories: Mapping of variant name -> raw factory; each raw factory is
      passed through `_produce_variant_factory`.
    testcases: Sequence of `(name, *args)` tuples. If empty, a single
      `("variant",)` case is used.

  Returns:
    A decorator with one case per (testcase, variant) pair, named
    `<name>_<variant_name>`, whose arguments are the original args followed by
    the produced variant factory.
  """
  base_cases = testcases if testcases else [("variant",)]
  enhanced = [
      # The variant factory is always appended as the last argument.
      (case[0] + "_" + variant_name,) + tuple(case[1:]) +
      (_produce_variant_factory(raw_factory),)
      for case in base_cases
      for variant_name, raw_factory in factories.items()
  ]
  return parameterized.named_parameters(*enhanced)
Esempio n. 17
0
def _VersionTuple(version):
  """Return the leading numeric components of `version` as a tuple of ints.

  Parsing stops at the first dot-separated piece that does not start with an
  ASCII digit, and only the leading digits of each piece are kept, e.g.
  '2.10.0' -> (2, 10, 0) and '2.3.0rc1' -> (2, 3, 0).
  """
  numbers = []
  for piece in version.split('.'):
    digits = ''
    for ch in piece:
      if not '0' <= ch <= '9':
        break
      digits += ch
    if not digits:
      break
    numbers.append(int(digits))
  return tuple(numbers)


def RecurrentTestParameters(test_fn):
  """Run `test_fn` once as '_defun' and once as '_function'.

  Before each run, the three tf.function-related FLAGS are set to the case's
  value. They are restored afterwards so later tests do not inherit them. With
  tf.function on TensorFlow older than 2.3.0, the case returns early.

  Args:
    test_fn: A test method taking only `self`.

  Returns:
    The parameterized wrapper produced by `parameterized.named_parameters`.
  """

  def WrappedTestFn(self, use_tf_function):
    # TODO(laigd): remove this check when 312743821 and 313682500 are in the
    # release.
    # Compare numeric components. A plain string comparison would treat e.g.
    # '2.10.0' as older than '2.3.0' and silently skip the test.
    if use_tf_function and (
        _VersionTuple(tf.compat.v1.__version__) < (2, 3, 0)):
      return
    saved = (FLAGS.if_use_tf_function, FLAGS.while_loop_use_tf_function,
             FLAGS.call_defun_use_tf_function)
    FLAGS.if_use_tf_function = use_tf_function
    FLAGS.while_loop_use_tf_function = use_tf_function
    FLAGS.call_defun_use_tf_function = use_tf_function
    try:
      test_fn(self)
    finally:
      # Restore so this case's setting does not leak into tests run later.
      (FLAGS.if_use_tf_function, FLAGS.while_loop_use_tf_function,
       FLAGS.call_defun_use_tf_function) = saved

  decorator = parameterized.named_parameters(
      ('_defun', False),
      ('_function', True),
  )
  return decorator(WrappedTestFn)
Esempio n. 18
0
    def decorator(test_method_or_class):
        """Apply every entry of `combinations` to a test method or class.

        Each combination (an `OrderedDict` of argument name -> value) becomes
        one named test case. Its `testcase_name` is "_test" followed by
        "_<key>_<value>" for every entry, with non-alphanumeric characters
        removed. `combinations` and `test_combinations` come from the
        enclosing scope (not visible here).

        Args:
          test_method_or_class: A test function, or a class whose functions
            named with `unittest.TestLoader.testMethodPrefix` are each
            replaced by their parameterized variants.

        Returns:
          The same class object (mutated in place) when given a class;
          otherwise the result of `parameterized.named_parameters` applied
          to the augmented test method.
        """

        # Generate good test names that can be used with --test_filter.
        named_combinations = []
        for combination in combinations:
            # We use OrderedDicts in `combine()` and `times()` to ensure stable
            # order of keys in each dictionary.
            assert isinstance(combination, OrderedDict)
            # `_get_name` (defined elsewhere) renders each value as a string;
            # it also receives the entry's position `i`.
            name = "".join([
                "_{}_{}".format(
                    "".join(filter(str.isalnum, key)),
                    "".join(filter(str.isalnum, _get_name(value, i))))
                for i, (key, value) in enumerate(combination.items())
            ])
            named_combinations.append(
                OrderedDict(
                    list(combination.items()) +
                    [("testcase_name", "_test{}".format(name))]))

        if isinstance(test_method_or_class, type):
            class_object = test_method_or_class
            class_object._test_method_ids = test_method_ids = {}
            # Iterate over a snapshot of the class dict: the loop body deletes
            # and adds class attributes. Note `name` is rebound here to the
            # attribute name.
            for name, test_method in six.iteritems(
                    class_object.__dict__.copy()):
                if (name.startswith(unittest.TestLoader.testMethodPrefix)
                        and isinstance(test_method, types.FunctionType)):
                    # Replace the original test with its generated variants,
                    # using leading-underscore (non-public) helpers of
                    # `parameterized`; these may change between versions.
                    delattr(class_object, name)
                    methods = {}
                    parameterized._update_class_dict_for_param_test_case(
                        class_object.__name__, methods, test_method_ids, name,
                        parameterized._ParameterizedTestIter(
                            _augment_with_special_arguments(
                                test_method,
                                test_combinations=test_combinations),
                            named_combinations, parameterized._NAMED, name))
                    for method_name, method in six.iteritems(methods):
                        setattr(class_object, method_name, method)

            return class_object
        else:
            test_method = _augment_with_special_arguments(
                test_method_or_class, test_combinations=test_combinations)
            return parameterized.named_parameters(
                *named_combinations)(test_method)
Esempio n. 19
0
  def decorator(test_method):
    """Add `key`/`value` (from the enclosing scope) to every test case.

    If `test_method` is already the iterable result of another dict_decorator,
    each of its `(suffix, dict)` cases gets `key: value` added and
    `_<key>_<value>` appended to its suffix; the object is updated in place and
    returned. Otherwise a single case `('_<key>_<value>', {key: value})` is
    created for the plain test method.
    """
    if isinstance(test_method, abc.Iterable):
      def _extend(case):
        # each test is a ('test_suffix', dict) tuple
        kwargs = case[1].copy()
        kwargs[key] = value
        return ('%s_%s_%s' % (case[0], key, value), kwargs)

      test_method.testcases = [_extend(case) for case in test_method.testcases]
      return test_method

    # 'test_method' here is the original, undecorated test method.
    only_case = (('_%s_%s') % (key, value), {key: value})
    return parameterized.named_parameters(only_case)(test_method)
class StringArrayTest(test.TestCase, parameterized.TestCase):
  """Checks that `np_array_ops` constructors produce `bytes` for str inputs."""

  # Each case: (name, input value, requested dtype, expected bytes).
  StringParameters = parameterized.named_parameters(  # pylint: disable=invalid-name
      # Tensorflow always encodes python string into bytes, regardless of
      # requested dtype.
      ('str_u8', 'abcde\U0001f005', 'U8', b'abcde\xf0\x9f\x80\x85'),
      ('str_s8', 'abcde\U0001f005', 'S8', b'abcde\xf0\x9f\x80\x85'),
      ('str_none', 'abcde\U0001f005', None, b'abcde\xf0\x9f\x80\x85'),
      ('zstr_u8', '\0abcde\U0001f005', 'U8', b'\0abcde\xf0\x9f\x80\x85'),
      ('zstr_s8', '\0abcde\U0001f005', 'S8', b'\0abcde\xf0\x9f\x80\x85'),
      ('zstr_none', '\0abcde\U0001f005', None, b'\0abcde\xf0\x9f\x80\x85'),
      ('bytes_u8', b'abcdef', 'U8', b'abcdef'),
      ('bytes_s8', b'abcdef', 'S8', b'abcdef'),
      ('bytes_none', b'abcdef', None, b'abcdef'),
      ('zbytes_u8', b'\0abcdef', 'U8', b'\0abcdef'),
      ('zbytes_s8', b'\0abcdef', 'S8', b'\0abcdef'),
      ('zbytes_none', b'\0abcdef', None, b'\0abcdef'),
  )

  def _assertBytesTensor(self, tensor, expected):
    """Assert `tensor.numpy()` is a `bytes` object equal to `expected`."""
    value = tensor.numpy()
    self.assertIsInstance(value, bytes)
    self.assertEqual(value, expected)

  @StringParameters
  def testArray(self, a, dtype, a_as_bytes):
    self._assertBytesTensor(np_array_ops.array(a, dtype=dtype), a_as_bytes)

  @StringParameters
  def testAsArray(self, a, dtype, a_as_bytes):
    self._assertBytesTensor(np_array_ops.asarray(a, dtype=dtype), a_as_bytes)

  @StringParameters
  def testZerosLike(self, a, dtype, unused_a_as_bytes):
    self._assertBytesTensor(np_array_ops.zeros_like(a, dtype=dtype), b'')

  @StringParameters
  def testEmptyLike(self, a, dtype, unused_a_as_bytes):
    self._assertBytesTensor(np_array_ops.empty_like(a, dtype=dtype), b'')
Esempio n. 21
0
def test_cases():
    """Return named parameters over batch, context_dim, agents and policy info."""
    # (testcase_name, batch_size, context_dim, num_agents, emit_policy_info)
    rows = (
        ('_batch1_contextdim10_numagents2_info', 1, 10, 2, True),
        ('_batch3_contextdim7_numagents17_noinfo', 3, 7, 17, False),
        ('_batch4_contextdim5_numagents10_info', 4, 5, 10, True),
    )
    return parameterized.named_parameters(*[
        dict(testcase_name=name, batch_size=batch, context_dim=dim,
             num_agents=agents, emit_policy_info=info)
        for name, batch, dim, agents, info in rows
    ])
Esempio n. 22
0
  def decorator(test_method_or_class):
    """Apply every entry of `combinations` to a test method or class.

    Each combination (an `OrderedDict` of argument name -> value) becomes one
    named test case. Its `testcase_name` is "_test" followed by
    "_<key>_<str(value)>" for every entry, with non-alphanumeric characters
    removed. `combinations` comes from the enclosing scope (not visible here).

    Args:
      test_method_or_class: A test function, or a class whose functions named
        with `unittest.TestLoader.testMethodPrefix` are each replaced by
        their parameterized variants.

    Returns:
      The same class object (mutated in place) when given a class; otherwise
      the result of `parameterized.named_parameters` applied to the augmented
      test method.
    """

    # Generate good test names that can be used with --test_filter.
    named_combinations = []
    for combination in combinations:
      # We use OrderedDicts in `combine()` and `times()` to ensure stable
      # order of keys in each dictionary.
      assert isinstance(combination, OrderedDict)
      name = "".join([
          "_{}_{}".format(
              "".join(filter(str.isalnum, key)),
              "".join(filter(str.isalnum, str(value))))
          for key, value in combination.items()
      ])
      named_combinations.append(
          OrderedDict(
              list(combination.items()) + [("testcase_name",
                                            "_test{}".format(name))]))

    if isinstance(test_method_or_class, type):
      class_object = test_method_or_class
      class_object._test_method_ids = test_method_ids = {}
      # Iterate over a snapshot of the class dict: the loop body deletes and
      # adds class attributes. Note `name` is rebound here to the attribute
      # name.
      for name, test_method in six.iteritems(class_object.__dict__.copy()):
        if (name.startswith(unittest.TestLoader.testMethodPrefix) and
            isinstance(test_method, types.FunctionType)):
          # Replace the original test with its generated variants, using
          # leading-underscore (non-public) helpers of `parameterized`; these
          # may change between versions.
          delattr(class_object, name)
          methods = {}
          parameterized._update_class_dict_for_param_test_case(
              class_object.__name__, methods, test_method_ids, name,
              parameterized._ParameterizedTestIter(
                  _augment_with_special_arguments(test_method),
                  named_combinations, parameterized._NAMED, name))
          for method_name, method in six.iteritems(methods):
            setattr(class_object, method_name, method)

      return class_object
    else:
      test_method = _augment_with_special_arguments(test_method_or_class)
      return parameterized.named_parameters(*named_combinations)(test_method)
Esempio n. 23
0
VJP_SAMPLE_BLOCKLIST = ()
VJP_LOGPROB_SAMPLE_BLOCKLIST = ()
VJP_LOGPROB_PARAM_BLOCKLIST = ('VonMisesFisher',)  # http://b/171079052

PYTREE_BLOCKLIST = ('Bates', 'TransformedDistribution')

DEFAULT_MAX_EXAMPLES = 3


def _named_by_dist(dist_names):
    """Return named parameters with one `dist_name` case per entry."""
    return parameterized.named_parameters(
        {'testcase_name': name, 'dist_name': name} for name in dist_names)


# All instantiable base distributions plus every meta distribution except
# 'Mixture', in sorted order.
test_all_distributions = _named_by_dist(sorted(
    list(dhps.INSTANTIABLE_BASE_DISTS.keys()) +
    [d for d in dhps.INSTANTIABLE_META_DISTS if d != 'Mixture']))

test_base_distributions = _named_by_dist(
    sorted(dhps.INSTANTIABLE_BASE_DISTS.keys()))


class JitTest(test_util.TestCase):
    @test_all_distributions
    @hp.given(hps.data())
    @tfp_hps.tfp_hp_settings(default_max_examples=DEFAULT_MAX_EXAMPLES)
    def testSample(self, dist_name, data):
        if (dist_name in JIT_SAMPLE_BLOCKLIST) != FLAGS.blocklists_only:
from absl.testing import parameterized
import numpy as np
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import

from tf_agents.bandits.networks import global_and_arm_feature_network as gafn
from tf_agents.bandits.specs import utils as bandit_spec_utils
from tf_agents.specs import tensor_spec
from tf_agents.utils import test_utils

# Two (batch_size, feature_dim, num_actions) configurations used by
# `@parameters`-decorated tests below.
parameters = parameterized.named_parameters(
    dict(testcase_name='batch2feat4act3', batch_size=2, feature_dim=4,
         num_actions=3),
    dict(testcase_name='batch1feat7act9', batch_size=1, feature_dim=7,
         num_actions=9),
)


class GlobalAndArmFeatureNetworkTest(parameterized.TestCase,
                                     test_utils.TestCase):
    @parameters
    def testCreateFeedForwardCommonTowerNetwork(self, batch_size, feature_dim,
                                                num_actions):
        obs_spec = bandit_spec_utils.create_per_arm_observation_spec(
            7, feature_dim, num_actions)
        net = gafn.create_feed_forward_common_tower_network(
Esempio n. 25
0
# pyformat: disable
_FILE_SPEC_VALUES = (
    ('BytesIO', BytesIOSpec),
    ('FileIO', FileIOSpec),
    ('BufferedIO', BufferedIOSpec),
    ('BuiltinFile', BuiltinFileSpec),
    ('TensorFlowGFile', TensorFlowGFileSpec),
)

_RANDOM_ACCESS_VALUES = (
    ('randomAccess', True),
    ('streamAccess', False),
)

_PARALLELISM_VALUES = (
    ('serial', 0),
    ('parallel', 10),
)
# pyformat: enable

# Named-parameter decorators. The combined ones pass the result of
# `combine_named_parameters` (defined elsewhere; presumably the cross product
# of its arguments — confirm) as a single iterable of cases.
_PARAMETERIZE_BY_RANDOM_ACCESS = parameterized.named_parameters(
    *_RANDOM_ACCESS_VALUES)

_PARAMETERIZE_BY_RANDOM_ACCESS_AND_PARALLELISM = (
    parameterized.named_parameters(combine_named_parameters(
        _RANDOM_ACCESS_VALUES, _PARALLELISM_VALUES)))

_PARAMETERIZE_BY_FILE_SPEC_AND_RANDOM_ACCESS_AND_PARALLELISM = (
    parameterized.named_parameters(combine_named_parameters(
        _FILE_SPEC_VALUES, _RANDOM_ACCESS_VALUES, _PARALLELISM_VALUES)))


class RecordsTest(parameterized.TestCase):

  def corrupt_at(self, files, index):
    byte_reader = files.reading_open()
Esempio n. 26
0
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test


class PathLike(object):
    """Minimal path-like wrapper around a path string.

    Implements only `__fspath__`, so `os.fspath()` and APIs that accept
    `os.PathLike` objects resolve an instance to its `name`.
    """

    def __init__(self, name):
        # The wrapped path, returned verbatim by __fspath__.
        self.name = name

    def __fspath__(self):
        path = self.name
        return path


def _join_as_pathlike(*paths):
    """Join `paths` with `os.path.join` and wrap the result in `PathLike`."""
    return PathLike(os.path.join(*paths))


# Runs a test twice: once building paths as plain strings ("str") and once as
# PathLike objects ("pathlike").
run_all_path_types = parameterized.named_parameters(
    ("str", os.path.join),
    ("pathlike", _join_as_pathlike),
)


class FileIoTest(test.TestCase, parameterized.TestCase):
    """Tests for `file_io`, run against a per-test scratch directory."""

    def setUp(self):
        """Run base-class setup, then create `base_dir` in the temp dir."""
        # The original override skipped the base-class per-test setup.
        super(FileIoTest, self).setUp()
        self._base_dir = os.path.join(self.get_temp_dir(), "base_dir")
        file_io.create_dir(self._base_dir)

    def tearDown(self):
        """Remove the scratch directory, then run base-class teardown."""
        try:
            file_io.delete_recursively(self._base_dir)
        finally:
            # Always run the base teardown, even if cleanup fails.
            super(FileIoTest, self).tearDown()

    def testEmptyFilename(self):
        f = file_io.FileIO("", mode="r")
        with self.assertRaises(errors.NotFoundError):
            _ = f.read()
Esempio n. 27
0
def test_cases():
    """Return named parameters over model type, logit-noise distribution and class count.

    Every case passes `logit_noise`, `num_classes` and `model_type`; the
    Ensemble* cases always use `tfp.distributions.Normal`.
    """
    normal = tfp.distributions.Normal
    logistic = tfp.distributions.Logistic
    gumbel = tfp.distributions.Gumbel
    # (testcase_name, logit_noise, num_classes, model_type)
    rows = (
        ('_MCSoftmaxDense_logit_noise_normal_10', normal, 10, 'MCSoftmaxDense'),
        ('_MCSoftmaxDense_logit_noise_logistic_10', logistic, 10,
         'MCSoftmaxDense'),
        ('_MCSoftmaxDense_logit_noise_gumbel_10', gumbel, 10, 'MCSoftmaxDense'),
        ('_MCSoftmaxDenseFA_logit_noise_normal_10', normal, 10,
         'MCSoftmaxDenseFA'),
        ('_MCSigmoidDenseFA_logit_noise_normal_10', normal, 10,
         'MCSigmoidDenseFA'),
        ('_MCSoftmaxDenseFAPE_logit_noise_normal_10', normal, 10,
         'MCSoftmaxDenseFAPE'),
        ('_MCSigmoidDenseFAPE_logit_noise_normal_10', normal, 10,
         'MCSigmoidDenseFAPE'),
        ('_MCSoftmaxDense_logit_noise_normal_2', normal, 2, 'MCSoftmaxDense'),
        ('_MCSoftmaxDense_logit_noise_logistic_2', logistic, 2,
         'MCSoftmaxDense'),
        ('_MCSoftmaxDense_logit_noise_gumbel_2', gumbel, 2, 'MCSoftmaxDense'),
        ('_Exact_logit_noise_normal_2', normal, 2, 'Exact'),
        ('_Exact_logit_noise_logistic_2', logistic, 2, 'Exact'),
        ('_EnsembleGibbsCE_10', normal, 10, 'EnsembleGibbsCE'),
        ('_EnsembleGibbsCE_2', normal, 2, 'EnsembleGibbsCE'),
        ('_EnsembleEnsembleCE_10', normal, 10, 'EnsembleEnsembleCE'),
        ('_EnsembleEnsembleCE_2', normal, 2, 'EnsembleEnsembleCE'),
    )
    return parameterized.named_parameters(*[
        dict(testcase_name=name, logit_noise=noise, num_classes=classes,
             model_type=model)
        for name, noise, classes, model in rows
    ])
Esempio n. 28
0
# Each case: (test name, larq layer, equivalent Keras layer, input shape,
# layer constructor kwargs).
_ALL_LAYER_CASES = (
    ("QuantDense", lq.layers.QuantDense, tf.keras.layers.Dense,
     (3, 2), {"units": 3}),
    ("QuantConv1D", lq.layers.QuantConv1D, tf.keras.layers.Conv1D,
     (2, 3, 7), {"filters": 2, "kernel_size": 3}),
    ("QuantConv2D", lq.layers.QuantConv2D, tf.keras.layers.Conv2D,
     (2, 3, 7, 6), {"filters": 2, "kernel_size": 3}),
    ("QuantConv3D", lq.layers.QuantConv3D, tf.keras.layers.Conv3D,
     (2, 3, 7, 6, 5), {"filters": 2, "kernel_size": 3}),
    ("QuantConv2DTranspose", lq.layers.QuantConv2DTranspose,
     tf.keras.layers.Conv2DTranspose,
     (2, 3, 7, 6), {"filters": 2, "kernel_size": 3}),
    ("QuantConv3DTranspose", lq.layers.QuantConv3DTranspose,
     tf.keras.layers.Conv3DTranspose,
     (2, 3, 7, 6, 5), {"filters": 2, "kernel_size": 3}),
    ("QuantLocallyConnected1D", lq.layers.QuantLocallyConnected1D,
     tf.keras.layers.LocallyConnected1D,
     (2, 8, 5), {"filters": 4, "kernel_size": 3}),
    ("QuantLocallyConnected2D", lq.layers.QuantLocallyConnected2D,
     tf.keras.layers.LocallyConnected2D,
     (8, 6, 10, 4), {"filters": 3, "kernel_size": 3}),
)

parameterized_all_layers = parameterized.named_parameters(*_ALL_LAYER_CASES)
Esempio n. 29
0
# pyformat: disable
_FILE_SPEC_VALUES = (
    ('BytesIO', BytesIOSpec),
    ('FileIO', FileIOSpec),
    ('BufferedIO', BufferedIOSpec),
    ('BuiltinFile', BuiltinFileSpec),
    ('TensorFlowGFile', TensorFlowGFileSpec),
)

_RANDOM_ACCESS_VALUES = (
    ('randomAccess', RandomAccess.RANDOM_ACCESS),
    ('sequentialAccessDetected', RandomAccess.SEQUENTIAL_ACCESS_DETECTED),
    ('sequentialAccessExplicit', RandomAccess.SEQUENTIAL_ACCESS_EXPLICIT),
)

_PARALLELISM_VALUES = (
    ('serial', 0),
    ('parallel', 10),
)
# pyformat: enable

# Single-axis decorators: one named case per entry.
_PARAMETERIZE_BY_FILE_SPEC = parameterized.named_parameters(
    *_FILE_SPEC_VALUES)

_PARAMETERIZE_BY_RANDOM_ACCESS = parameterized.named_parameters(
    *_RANDOM_ACCESS_VALUES)

# Multi-axis decorators. Each passes the result of `combine_named_parameters`
# (defined elsewhere; presumably the cross product of its arguments — confirm)
# as a single iterable of cases.
_PARAMETERIZE_BY_RANDOM_ACCESS_AND_PARALLELISM = (
    parameterized.named_parameters(combine_named_parameters(
        _RANDOM_ACCESS_VALUES, _PARALLELISM_VALUES)))

_PARAMETERIZE_BY_FILE_SPEC_AND_PARALLELISM = (
    parameterized.named_parameters(combine_named_parameters(
        _FILE_SPEC_VALUES, _PARALLELISM_VALUES)))

_PARAMETERIZE_BY_FILE_SPEC_AND_RANDOM_ACCESS_AND_PARALLELISM = (
    parameterized.named_parameters(combine_named_parameters(
        _FILE_SPEC_VALUES, _RANDOM_ACCESS_VALUES, _PARALLELISM_VALUES)))
Esempio n. 30
0
    import tensorflow as tf  # type: ignore[import]
except ImportError:
    tf = None

# Let absl-parsed command-line flags update `config` (presumably the JAX
# config object imported above — confirm) before the tests below are defined.
config.parse_flags_with_absl()


def _maybe_jit(with_jit: bool, func: Callable) -> Callable:
    """Return `jax.jit(func)` when `with_jit` is truthy, else `func` unchanged."""
    return jax.jit(func) if with_jit else func


# Runs a test with `with_jit=True` (suffix "_jit") and `with_jit=False`
# (no suffix).
parameterized_jit = parameterized.named_parameters(
    {"testcase_name": "_jit", "with_jit": True},
    {"testcase_name": "", "with_jit": False},
)


class CallTfTest(jtu.JaxTestCase):
    def setUp(self):
        if tf is None:
            raise unittest.SkipTest("Test requires tensorflow")
        # TODO(b/171320191): this line works around a missing context initialization
        # bug in TensorFlow.
        _ = tf.add(1, 1)
        super().setUp()

    @parameterized_jit
    def test_eval_scalar_arg(self, with_jit=False):
        x = 3.
# pylint: disable=no-name-in-module

from tensorflow_probability.python.distributions._jax import hypothesis_testlib as dhps
from tensorflow_probability.python.experimental.substrates.jax import tf2jax as tf
from tensorflow_probability.python.internal._jax import hypothesis_testlib as tfp_hps
from tensorflow_probability.python.internal._jax import test_util

# Distribution names to skip in the JIT tests (currently none).
JIT_SAMPLE_BLACKLIST = set()
JIT_LOGPROB_BLACKLIST = set()

# Distribution names to skip in the vmap tests (currently none).
VMAP_SAMPLE_BLACKLIST = set()
VMAP_LOGPROB_BLACKLIST = set()

# One named case per instantiable base distribution, in sorted order.
test_all_distributions = parameterized.named_parameters([
    dict(testcase_name=dname, dist_name=dname)
    for dname in sorted(dhps.INSTANTIABLE_BASE_DISTS.keys())
])


class JitTest(test_util.TestCase):

  @test_all_distributions
  @hp.given(hps.data())
  @tfp_hps.tfp_hp_settings()
  def testSample(self, dist_name, data):
    if dist_name in JIT_SAMPLE_BLACKLIST:
      self.skipTest('Distribution currently broken.')
    dist = data.draw(dhps.distributions(enable_vars=False,
                                        dist_name=dist_name))
    def _sample(seed):
      return dist.sample(seed=seed)