def _replace_event_shape_in_tensorshape(
    input_tensorshape, event_shape_in, event_shape_out):
  """Replaces the event shape dims of a `TensorShape`.

  Args:
    input_tensorshape: a `TensorShape` instance in which to attempt replacing
      event shape.
    event_shape_in: `Tensor` shape representing the event shape expected to be
      present in (rightmost dims of) `input_tensorshape`. Must be compatible
      with the rightmost dims of `input_tensorshape`.
    event_shape_out: `Tensor` shape representing the new event shape, i.e.,
      the replacement of `event_shape_in`.

  Returns:
    output_tensorshape: `TensorShape` with the rightmost `event_shape_in`
      replaced by `event_shape_out`. Might be partially defined, i.e.,
      `TensorShape(None)`.
    is_validated: Python `bool` indicating static validation happened.

  Raises:
    ValueError: if `input_tensorshape` has lower rank than the number of
      elements of `event_shape_in`.
    ValueError: if we can determine the event shape portion of
      `input_tensorshape` as well as `event_shape_in` both statically, and
      they are not compatible. "Compatible" here means that they are identical
      on any dims that are not negative (e.g. -1) in `event_shape_in`.
  """
  input_rank = tensorshape_util.rank(input_tensorshape)
  num_event_dims = tensorshape_util.num_elements(event_shape_in.shape)
  # Without both ranks we cannot split the shape, so nothing is known and
  # nothing was validated.
  if input_rank is None or num_event_dims is None:
    return tf.TensorShape(None), False

  split = input_rank - num_event_dims
  if split < 0:
    raise ValueError(
        'Input has lower rank ({}) than `event_shape_ndims` ({}).'.format(
            input_rank, num_event_dims))
  batch_part = input_tensorshape[:split]
  event_part = input_tensorshape[split:]

  # Compare the static event dims against `event_shape_in`, skipping every
  # position where `event_shape_in` holds a negative entry.
  static_event_shape_in = tf.get_static_value(event_shape_in)
  is_validated = (tensorshape_util.is_fully_defined(event_part) and
                  static_event_shape_in is not None)
  if is_validated:
    static_event_part = np.int32(event_part)
    keep = static_event_shape_in >= 0
    if not np.all(static_event_part[keep] == static_event_shape_in[keep]):
      raise ValueError(
          'Input `event_shape` does not match `event_shape_in` '
          '({} vs {}).'.format(static_event_part, static_event_shape_in))

  new_event_part = tensorshape_util.constant_value_as_shape(event_shape_out)
  if tensorshape_util.rank(new_event_part) is None:
    return tf.TensorShape(None), is_validated
  return (tensorshape_util.concatenate(batch_part, new_event_part),
          is_validated)
def _convert_shape(input_shape):
  """Normalizes `input_shape` to a `TensorShape`, or to a tuple of dims.

  The choice is driven by `to_tuples`, read from the enclosing scope.
  """
  shape = tf.TensorShape(input_shape)
  return tuple(shape.as_list()) if to_tuples else shape
def compute_output_shape(self, input_shape):
  """Returns `self.get_output_shape(input_shape)` wrapped in a `TensorShape`."""
  output_shape = self.get_output_shape(input_shape)
  return tf.TensorShape(output_shape)
def _batch_shape(self):
  """Broadcasts the parent kernel's batch shape with the static shape of `df`.

  When `df` is `None` it contributes a scalar (rank-0) shape.
  """
  parent_shape = super(GeneralizedMatern, self)._batch_shape()
  if self.df is None:
    df_shape = tf.TensorShape([])
  else:
    df_shape = self.df.shape
  return tf.broadcast_static_shape(parent_shape, df_shape)
def _batch_shape(self):
  """Static batch shape: the dims of `samples` to the left of `_samples_axis`.

  Returns an unknown-rank shape when the rank of `samples` is unknown.
  """
  samples_shape = self.samples.shape
  if tensorshape_util.rank(samples_shape) is not None:
    return samples_shape[:self._samples_axis]
  return tf.TensorShape(None)
def testMatrixEvent(self):
  """Checks shapes, log_prob and prob for a matrix-valued event."""
  batch_shape = [2]
  event_shape = [2, 3, 3]
  batch_shape_pl = tf1.placeholder_with_default(
      input=np.int32(batch_shape), shape=None, name='dynamic_batch_shape')
  event_shape_pl = tf1.placeholder_with_default(
      input=np.int32(event_shape), shape=None, name='dynamic_event_shape')

  scale = 2.
  loc = 0.
  fake_mvn_dynamic = self._cls()(
      distribution=tfd.Normal(loc=loc, scale=scale),
      bijector=DummyMatrixTransform(),
      batch_shape=batch_shape_pl,
      event_shape=event_shape_pl,
      validate_args=True)

  fake_mvn_static = self._cls()(
      distribution=tfd.Normal(loc=loc, scale=scale),
      bijector=DummyMatrixTransform(),
      batch_shape=batch_shape,
      event_shape=event_shape,
      validate_args=True)

  def actual_mvn_log_prob(x):
    # Normal log-density summed over the last 3 (event) dims, plus a
    # Jacobian-like term: det(x) of each trailing [3, 3] matrix, summed over
    # the remaining size-2 event axis.
    # NOTE(review): presumably mirrors DummyMatrixTransform's
    # log-det-Jacobian; confirm.
    return (np.sum(stats.norm(loc, scale).logpdf(x), axis=(-1, -2, -3)) +
            np.sum(np.linalg.det(x), axis=-1))

  self.assertAllEqual([2, 3, 3], fake_mvn_static.event_shape)
  self.assertAllEqual([2], fake_mvn_static.batch_shape)

  if not tf.executing_eagerly():
    self.assertAllEqual(tf.TensorShape(None), fake_mvn_dynamic.event_shape)
    self.assertAllEqual(tf.TensorShape(None), fake_mvn_dynamic.batch_shape)

  num_samples = 5e3
  for fake_mvn in [fake_mvn_static, fake_mvn_dynamic]:
    # Draw samples and keep the first 5 as evaluation points for the shape,
    # log_prob and prob checks below. No moment checks are made here.
    y = fake_mvn.sample(int(num_samples), seed=tfp_test_util.test_seed())
    x = y[0:5, ...]

    [
        x_,
        fake_event_shape_,
        fake_batch_shape_,
        fake_log_prob_,
        fake_prob_,
    ] = self.evaluate([
        x,
        fake_mvn.event_shape_tensor(),
        fake_mvn.batch_shape_tensor(),
        fake_mvn.log_prob(x),
        fake_mvn.prob(x),
    ])

    # Ensure all other functions work as intended.
    self.assertAllEqual([5, 2, 2, 3, 3], x_.shape)
    self.assertAllEqual([2, 3, 3], fake_event_shape_)
    self.assertAllEqual([2], fake_batch_shape_)
    self.assertAllClose(
        actual_mvn_log_prob(x_), fake_log_prob_, atol=0., rtol=1e-6)
    # With this many dimensions and samples, the direct space probability
    # may underflow.
    self.assertAllClose(
        np.exp(actual_mvn_log_prob(x_)), fake_prob_, atol=1e-12, rtol=1e-5)
def rank_only_shape(mindims, maxdims):
  """Hypothesis strategy for `TensorShape`s of known rank, unknown dims.

  The rank is drawn from the closed range `[mindims, maxdims]`.
  """
  ranks = hps.integers(min_value=mindims, max_value=maxdims)
  unknown = tf.TensorShape(None)
  return ranks.map(unknown.with_rank)
def _add_batch(shape):
  """Prepends `batch_size` (from the enclosing scope) to a known-rank shape."""
  return tf.TensorShape([batch_size, *shape.as_list()])
def trace_scan(loop_fn,
               initial_state,
               elems,
               trace_fn,
               trace_criterion_fn=None,
               static_trace_allocation_size=None,
               condition_fn=None,
               parallel_iterations=10,
               name=None):
  """A simplified version of `tf.scan` that has configurable tracing.

  This function repeatedly calls `loop_fn(state, elem)`, where `state` is the
  `initial_state` during the first iteration, and the return value of
  `loop_fn` for every iteration thereafter. `elem` is a slice of `elems` along
  the first dimension, accessed in order. Additionally, it calls `trace_fn` on
  the return value of `loop_fn`. The `Tensor`s in return values of `trace_fn`
  are stacked and returned from this function, such that the first dimension
  of those `Tensor`s matches the size of `elems` (when every step is traced;
  with `trace_criterion_fn` or `condition_fn` only the traced steps are
  written, and the stacked length depends on the allocation strategy below).

  Args:
    loop_fn: A callable that takes in a `Tensor` or a nested collection of
      `Tensor`s with the same structure as `initial_state`, a slice of `elems`
      and returns the same structure as `initial_state`.
    initial_state: A `Tensor` or a nested collection of `Tensor`s passed to
      `loop_fn` in the first iteration.
    elems: A `Tensor` that is split along the first dimension and each element
      of which is passed to `loop_fn`.
    trace_fn: A callable that takes in the return value of `loop_fn` and
      returns a `Tensor` or a nested collection of `Tensor`s.
    trace_criterion_fn: Optional callable that takes in the return value of
      `loop_fn` and returns a boolean `Tensor` indicating whether to trace it.
      If `None`, all steps are traced.
      Default value: `None`.
    static_trace_allocation_size: Optional Python `int` size of trace to
      allocate statically. This should be an upper bound on the number of
      steps traced and is used only when the length cannot be
      statically inferred (for example, if a `trace_criterion_fn` is
      specified).
      It is primarily intended for contexts where static shapes are required,
      such as in XLA-compiled code.
      Default value: `None`.
    condition_fn: Python `callable` additional loop termination condition,
      with signature `should_continue = condition_fn(step, state, num_traced,
      trace)`; returning `False` will terminate early and not scan over all of
      `elems`. `trace` here is the list of (flattened) `tf.TensorArray`s
      written so far, not the packed trace structure.
      Default value: `None`, which means no additional termination condition.
    parallel_iterations: Passed to the internal `tf.while_loop`.
    name: Name scope used in this function. Default: 'trace_scan'.

  Returns:
    final_state: The final return value of `loop_fn`.
    trace: The same structure as the return value of `trace_fn`, but with each
      `Tensor` being a stack of the corresponding `Tensors` in the return value
      of `trace_fn` for each slice of `elems`.
  """
  with tf.name_scope(name or 'trace_scan'), tf1.variable_scope(
      tf1.get_variable_scope()) as vs:
    # In graph mode, unless a caching device is already set, set the variable
    # scope's caching device to each op's own device.
    if vs.caching_device is None and not tf.executing_eagerly():
      vs.set_caching_device(lambda op: op.device)

    initial_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name='initial_state'),
        initial_state, expand_composites=True)
    elems = tf.convert_to_tensor(elems, name='elems')

    length = ps.size0(elems)

    # This is a TensorArray in part because of XLA, which had trouble with
    # non-statically known indices. I.e. elems[i] errored, but
    # elems_array.read(i) worked.
    elems_array = tf.TensorArray(elems.dtype, size=length,
                                 element_shape=elems.shape[1:])
    elems_array = elems_array.unstack(elems)

    # Initialize trace arrays. Choose how the trace TensorArrays are sized:
    if trace_criterion_fn is None and condition_fn is None:
      # Every step runs and is traced, so the trace has `length` entries;
      # only grow dynamically if `length` is a (non-static) Tensor.
      dynamic_size, initial_size = tf.is_tensor(length), length
    elif static_trace_allocation_size is not None:
      # Caller-provided upper bound on the number of traced steps.
      dynamic_size, initial_size = False, static_trace_allocation_size
    elif JAX_MODE or (not tf.executing_eagerly() and
                      control_flow_util.GraphOrParentsInXlaContext(
                          tf1.get_default_graph())):
      # JAX / XLA want fixed sizes; `length` bounds the number of traced
      # steps.
      dynamic_size, initial_size = False, length
    else:
      # Otherwise start empty and grow as steps are traced.
      dynamic_size, initial_size = True, 0
    # `trace_fn(initial_state)` is used only for the dtype/shape of each trace
    # component and for the output structure; its values are not written.
    initial_trace = trace_fn(initial_state)
    flat_initial_trace = tf.nest.flatten(initial_trace, expand_composites=True)
    trace_arrays = []
    for trace_elt in flat_initial_trace:
      trace_arrays.append(
          tf.TensorArray(
              trace_elt.dtype,
              size=initial_size,
              dynamic_size=dynamic_size,
              element_shape=trace_elt.shape))

    # Helper for writing a (structured) state to (structured) arrays.
    def trace_one_step(num_steps_traced, trace_arrays, state):
      return [
          ta.write(num_steps_traced, x) for ta, x in zip(
              trace_arrays,
              tf.nest.flatten(trace_fn(state), expand_composites=True))
      ]

    def _body(i, state, num_steps_traced, trace_arrays):
      elem = elems_array.read(i)
      state = loop_fn(state, elem)

      # Trace this step (and advance the write index) only when the criterion
      # holds; with no criterion the predicate is the literal `True`.
      trace_arrays, num_steps_traced = ps.cond(
          trace_criterion_fn(state) if trace_criterion_fn else True,
          lambda: (  # pylint: disable=g-long-lambda
              trace_one_step(num_steps_traced, trace_arrays, state),
              num_steps_traced + 1),
          lambda: (trace_arrays, num_steps_traced))

      return i + 1, state, num_steps_traced, trace_arrays

    if condition_fn is None:
      cond = lambda i, *_: i < length
    else:
      # `rest` is (state, num_steps_traced, trace_arrays).
      cond = lambda i, *rest: (i < length) & condition_fn(i, *rest)

    _, final_state, _, trace_arrays = tf.while_loop(
        cond=cond,
        body=_body,
        loop_vars=(0, initial_state, 0, trace_arrays),
        parallel_iterations=parallel_iterations)

    # unflatten
    stacked_trace = tf.nest.pack_sequence_as(
        initial_trace, [ta.stack() for ta in trace_arrays],
        expand_composites=True)

    # Restore the static length if we know it.
    # NOTE(review): in the JAX/XLA branch `initial_size` is `length`; this
    # relies on it being a static value there. Confirm.
    static_length = tf.TensorShape(None if dynamic_size else initial_size)

    def _merge_static_length(x):
      tensorshape_util.set_shape(x, static_length.concatenate(x.shape[1:]))
      return x

    stacked_trace = tf.nest.map_structure(
        _merge_static_length, stacked_trace, expand_composites=True)
    return final_state, stacked_trace
def _event_shape(self):
  """Square static event shape `[n, n]`, `n` = scale operator domain dim."""
  n = self.scale_operator.domain_dimension
  return tf.TensorShape([n] * 2)
def event_shape(self):
  """Returns `TensorShape([1])`, presenting this as a length-1 vector event."""
  return tf.TensorShape((1,))
def compute_output_shape(self, input_shape):
  """Output shape of the layer.

  In `INT` mode the input shape is passed through unchanged. Otherwise the
  result is `[input_shape[0], depth]`, where `depth` is `max_tokens` when
  `pad_to_max_tokens` is set and the frozen vocabulary size otherwise.
  """
  if self.output_mode == INT:
    return input_shape
  if self.pad_to_max_tokens:
    depth = self.max_tokens
  else:
    depth = self._frozen_vocab_size
  batch_dim = input_shape[0]
  return tf.TensorShape([batch_dim, depth])
def testParamStaticShapes(self):
  """Checks static param shapes for a list and an equivalent `TensorShape`."""
  expected = [10, 3, 4]
  for sample_shape in (expected, tf.TensorShape(expected)):
    self._testParamStaticShapes(sample_shape, expected)
def convert_to_batch_shape(s):
  """Returns `[1] + dims(s)` as a 1-D tensor.

  The leading 1 is the batch dimension: for recurrent variational dropout the
  same dropout mask is used for all batch elements.
  """
  dims = tf.TensorShape(s).as_list()
  return tf.concat(([1], dims), axis=0)
return rng.uniform(shape=shape, dtype=sig.dtype, minval=minval, maxval=maxval) return math_lib.nested_map(f, input_sig) def Mod(n): # pylint: disable=invalid-name return layers.Fn("Mod", lambda x: x % n) # Format: # (trax-layer maker, input shapes, input dtype, can handle None batch size?) _LAYERS = [ (lambda: layers.Dense(3), tf.TensorShape([4]), onp.float32, True), (mlp.PureMLP, tf.TensorShape([4]), onp.float32, False), (lambda: layers.Serial(Mod(8), transformer.TransformerLM(8)), tf.TensorShape([4]), onp.int32, False), ] _RNG_UPDATERS = [ lambda x: x, lambda rng: math_lib.random.split(rng, 1)[0], ] # Needs tf.test.TestCase for `assertAllClose` and `get_temp_dir` class Trax2KerasTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters([ {
def _batch_shape(self):
  """Scalar (rank-0) static batch shape."""
  return tf.TensorShape(())
def _testMVN(self,
             base_distribution_class,
             base_distribution_kwargs,
             batch_shape=(),
             event_shape=(),
             not_implemented_message=None):
  """Checks an Affine-transformed fake MVN against scipy, static and dynamic.

  Unsupported methods (cdf and friends) must raise `NotImplementedError`
  whose message matches `not_implemented_message` (any message if `None`).
  """
  # Overriding shapes must be compatible w/bijector; most bijectors are
  # batch_shape agnostic and only care about event_ndims.
  # In the case of `Affine`, if we got it wrong then it would fire an
  # exception due to incompatible dimensions.
  batch_shape_pl = tf1.placeholder_with_default(
      input=np.int32(batch_shape), shape=None, name='dynamic_batch_shape')
  event_shape_pl = tf1.placeholder_with_default(
      input=np.int32(event_shape), shape=None, name='dynamic_event_shape')
  fake_mvn_dynamic = self._cls()(
      distribution=base_distribution_class(
          validate_args=True, **base_distribution_kwargs),
      bijector=tfb.Affine(shift=self._shift, scale_tril=self._tril),
      batch_shape=batch_shape_pl,
      event_shape=event_shape_pl,
      validate_args=True)

  fake_mvn_static = self._cls()(
      distribution=base_distribution_class(
          validate_args=True, **base_distribution_kwargs),
      bijector=tfb.Affine(shift=self._shift, scale_tril=self._tril),
      batch_shape=batch_shape,
      event_shape=event_shape,
      validate_args=True)

  actual_mean = np.tile(self._shift, [2, 1])  # Affine elided this tile.
  actual_cov = np.matmul(self._tril, np.transpose(self._tril, [0, 2, 1]))

  def actual_mvn_log_prob(x):
    return np.concatenate([[  # pylint: disable=g-complex-comprehension
        stats.multivariate_normal(actual_mean[i],
                                  actual_cov[i]).logpdf(x[:, i, :])
    ] for i in range(len(actual_cov))]).T

  actual_mvn_entropy = np.concatenate(
      [[stats.multivariate_normal(actual_mean[i], actual_cov[i]).entropy()]
       for i in range(len(actual_cov))])

  self.assertAllEqual([3], fake_mvn_static.event_shape)
  self.assertAllEqual([2], fake_mvn_static.batch_shape)

  if not tf.executing_eagerly():
    self.assertAllEqual(tf.TensorShape(None), fake_mvn_dynamic.event_shape)
    self.assertAllEqual(tf.TensorShape(None), fake_mvn_dynamic.batch_shape)

  x = self.evaluate(fake_mvn_static.sample(5, seed=tfp_test_util.test_seed()))
  for unsupported_fn in (fake_mvn_static.log_cdf, fake_mvn_static.cdf,
                         fake_mvn_static.survival_function,
                         fake_mvn_static.log_survival_function):
    # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(NotImplementedError, not_implemented_message):
      unsupported_fn(x)

  num_samples = 7e3
  for fake_mvn in [fake_mvn_static, fake_mvn_dynamic]:
    # Ensure sample works by checking first, second moments.
    y = fake_mvn.sample(int(num_samples), seed=tfp_test_util.test_seed())
    x = y[0:5, ...]
    sample_mean = tf.reduce_mean(input_tensor=y, axis=0)
    centered_y = tf.transpose(a=y - sample_mean, perm=[1, 2, 0])
    sample_cov = tf.matmul(
        centered_y, centered_y, transpose_b=True) / num_samples
    [
        sample_mean_,
        sample_cov_,
        x_,
        fake_event_shape_,
        fake_batch_shape_,
        fake_log_prob_,
        fake_prob_,
        fake_mean_,
        fake_entropy_,
    ] = self.evaluate([
        sample_mean,
        sample_cov,
        x,
        fake_mvn.event_shape_tensor(),
        fake_mvn.batch_shape_tensor(),
        fake_mvn.log_prob(x),
        fake_mvn.prob(x),
        fake_mvn.mean(),
        fake_mvn.entropy(),
    ])

    self.assertAllClose(actual_mean, sample_mean_, atol=0.1, rtol=0.1)
    self.assertAllClose(actual_cov, sample_cov_, atol=0., rtol=0.1)

    # Ensure all other functions work as intended.
    self.assertAllEqual([5, 2, 3], x_.shape)
    self.assertAllEqual([3], fake_event_shape_)
    self.assertAllEqual([2], fake_batch_shape_)
    self.assertAllClose(
        actual_mvn_log_prob(x_), fake_log_prob_, atol=0., rtol=1e-6)
    self.assertAllClose(
        np.exp(actual_mvn_log_prob(x_)), fake_prob_, atol=0., rtol=1e-5)
    self.assertAllClose(actual_mean, fake_mean_, atol=0., rtol=1e-6)
    self.assertAllClose(actual_mvn_entropy, fake_entropy_, atol=0., rtol=1e-6)
def _event_shape(self):
  """Static event shape, built from `self.image_shape`."""
  image_shape = self.image_shape
  return tf.TensorShape(image_shape)
def draw_valid_slices(data, batch_shape):
  """Samples a legal (possibly empty) slice for shape batch_shape.

  Args:
    data: Hypothesis data object (presumably from `hps.data()`) used to draw
      every choice.
    batch_shape: Anything `tf.TensorShape` accepts; the shape being sliced.

  Returns:
    A tuple of slice items (ints, `slice`s, `Ellipsis`, `tf.newaxis`). When
    exactly one item was drawn, that item is sometimes returned unwrapped.
  """
  # We build up a list of slices in several stages:
  # 1. Choose 0 to batch_rank slices to come before an Ellipsis (...).
  # 2. Decide whether or not to add an Ellipsis; if using, updating the indexing
  #    used (e.g. batch_shape[i]) to identify safe bounds.
  # 3. Choose 0 to [remaining_dims] slices to come last.
  # 4. Decide where to insert between 0 and 7 newaxis slices.
  batch_shape = tf.TensorShape(batch_shape).as_list()
  slices = []
  batch_rank = len(batch_shape)
  arbitrary_slices = hps.tuples(
      hps.one_of(hps.just(None), hps.integers(min_value=-100, max_value=100)),
      hps.one_of(hps.just(None), hps.integers(min_value=-100, max_value=100)),
      hps.one_of(
          hps.just(None),
          hps.integers(min_value=-100, max_value=100).filter(
              lambda x: x != 0))).map(lambda tup: slice(*tup))

  def draw_index_or_slice(dim):
    """Draws an in-bounds int index into `dim`, or an arbitrary slice."""
    # An integer index is only safe for a known, nonzero dim: a size-0 dim
    # would need `hps.integers(0, -1)` (an invalid strategy) and an unknown
    # (None) dim has no bound. Slices are always legal.
    if not dim:
      return data.draw(arbitrary_slices)
    return data.draw(
        hps.one_of(hps.integers(min_value=0, max_value=dim - 1),
                   arbitrary_slices))

  # 1. Choose 0 to batch_rank slices to come before an Ellipsis (...).
  nslc_before_ellipsis = data.draw(
      hps.integers(min_value=0, max_value=batch_rank))
  for i in range(nslc_before_ellipsis):
    slices.append(draw_index_or_slice(batch_shape[i]))
  # 2. Decide whether or not to add an Ellipsis; if using, updating the indexing
  #    used (e.g. batch_shape[i]) to identify safe bounds.
  has_ellipsis = data.draw(hps.booleans().map(lambda x: (Ellipsis, x)))[1]
  nslc_after_ellipsis = data.draw(
      hps.integers(min_value=0, max_value=batch_rank - nslc_before_ellipsis))
  if has_ellipsis:
    slices.append(Ellipsis)
    # Trailing items index the last dims when an Ellipsis is present.
    remain_start, remain_end = (batch_rank - nslc_after_ellipsis, batch_rank)
  else:
    remain_start = nslc_before_ellipsis
    remain_end = nslc_before_ellipsis + nslc_after_ellipsis
  # 3. Choose 0 to [remaining_dims] slices to come last.
  for i in range(remain_start, remain_end):
    slices.append(draw_index_or_slice(batch_shape[i]))
  # 4. Decide where to insert between 0 and 7 newaxis slices.
  newaxis_positions = data.draw(
      hps.lists(hps.integers(min_value=0, max_value=len(slices)), max_size=7))
  # Insert from the highest position down so earlier positions stay valid.
  for i in sorted(newaxis_positions, reverse=True):
    slices.insert(i, tf.newaxis)
  slices = tuple(slices)
  # Since `d[0]` ==> `d.__getitem__(0)` instead of `d.__getitem__((0,))`;
  # and similarly `d[:3]` ==> `d.__getitem__(slice(None, 3))` instead of
  # `d.__getitem__((slice(None, 3),))`; it is useful to test such scenarios.
  if len(slices) == 1 and data.draw(hps.booleans()):
    # Sometimes only a single item non-tuple.
    return slices[0]
  return slices
def _entropy(self, **kwargs):
  """Entropy of `self.distribution`, routed through `_call_and_reshape_output`.

  NOTE(review): `[]` / `[tf.TensorShape([])]` presumably describe a scalar
  output event shape (dynamic and static forms); confirm against
  `_call_and_reshape_output`.
  """
  scalar_shape = tf.TensorShape([])
  return self._call_and_reshape_output(
      self.distribution.entropy,
      [],
      [scalar_shape],
      extra_kwargs=kwargs)
def _event_shape(self):
  """Square static event shape `[dimension, dimension]`."""
  d = self.dimension
  return tf.TensorShape((d, d))
def __init__(self,
             initial_distribution,
             transition_distribution,
             observation_distribution,
             num_steps,
             validate_args=False,
             allow_nan_stats=True,
             name='HiddenMarkovModel'):
  """Initialize hidden Markov model.

  Args:
    initial_distribution: A `Categorical`-like instance.
      Determines probability of first hidden state in Markov chain.
      The number of categories must match the number of categories of
      `transition_distribution` as well as both the rightmost batch
      dimension of `transition_distribution` and the rightmost batch
      dimension of `observation_distribution`.
    transition_distribution: A `Categorical`-like instance.
      The rightmost batch dimension indexes the probability distribution
      of each hidden state conditioned on the previous hidden state.
    observation_distribution: A `tfp.distributions.Distribution`-like
      instance.  The rightmost batch dimension indexes the distribution
      of each observation conditioned on the corresponding hidden state.
    num_steps: The number of steps taken in Markov chain. An integer valued
      tensor. The number of transitions is `num_steps - 1`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "HiddenMarkovModel".

  Raises:
    ValueError: if `num_steps` is not at least 1.
    ValueError: if `initial_distribution` does not have scalar `event_shape`.
    ValueError: if `transition_distribution` does not have scalar
      `event_shape`.
    ValueError: if `transition_distribution` and `observation_distribution`
      are fully defined but don't have matching rightmost dimension.

  NOTE(review): this constructor itself only raises when a statically known
  `num_steps` is not a scalar; the other checks listed above are presumably
  performed elsewhere (e.g. parameter validation). Confirm.
  """
  # Must stay the first statement: snapshots the constructor arguments (and
  # `self`) from `locals()` before any other local is bound.
  parameters = dict(locals())

  # pylint: disable=protected-access
  with tf.name_scope(name) as name:
    self._num_steps = tensor_util.convert_nonref_to_tensor(num_steps)
    self._initial_distribution = initial_distribution
    self._observation_distribution = observation_distribution
    self._transition_distribution = transition_distribution

    # Static event shape is [num_steps] + observation event shape, with an
    # unknown leading dim when `num_steps` is not statically known.
    num_steps_ = tf.get_static_value(num_steps)
    if num_steps_ is not None:
      if np.ndim(num_steps_) != 0:
        raise ValueError(
            '`num_steps` must be a scalar but it has rank {}'.format(
                np.ndim(num_steps_)))
      else:
        self._static_event_shape = tf.TensorShape([
            num_steps_
        ]).concatenate(self._observation_distribution.event_shape)
    else:
      self._static_event_shape = tf.TensorShape([None]).concatenate(
          self._observation_distribution.event_shape)

    # The rightmost batch dim of the transition and observation distributions
    # indexes hidden states (see Args), so it is dropped before broadcasting.
    self._static_batch_shape = tf.broadcast_static_shape(
        self._initial_distribution.batch_shape,
        tf.broadcast_static_shape(
            self._transition_distribution.batch_shape[:-1],
            self._observation_distribution.batch_shape[:-1]))

    # pylint: disable=protected-access
    super(HiddenMarkovModel, self).__init__(
        dtype=self._observation_distribution.dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
    # pylint: enable=protected-access

    # NOTE(review): re-assigns the dict captured above after the base
    # constructor ran; presumably to pin that exact dict. Confirm whether the
    # base class already stores it.
    self._parameters = parameters
def _batch_shape(self):
  """Broadcast of the static shapes of `amplitude` and `length_scale`.

  A parameter that is `None` contributes a scalar (rank-0) shape.
  """
  def shape_or_scalar(param):
    return tf.TensorShape([]) if param is None else param.shape

  return tf.broadcast_static_shape(
      shape_or_scalar(self.amplitude), shape_or_scalar(self.length_scale))
def testParamStaticShapes(self):
  """Runs `_testParamShapes` on a list and on the equivalent `TensorShape`."""
  dims = [7]
  for sample_shape in (dims, tf.TensorShape(dims)):
    self._testParamShapes(sample_shape)
def _event_shape(self):
  """Static event shape: the dims of `samples` right of `_samples_axis`.

  Returns an unknown-rank shape when the rank of `samples` is unknown.
  """
  samples_shape = self.samples.shape
  if tensorshape_util.rank(samples_shape) is not None:
    return samples_shape[self._samples_axis + 1:]
  return tf.TensorShape(None)
def _batch_shape(self):
  """Static batch shape built from the stored `_batch_shape_tuple`."""
  dims = self._batch_shape_tuple
  return tf.TensorShape(dims)
def _event_shape(self):
  """Scalar (rank-0) static event shape."""
  return tf.TensorShape(())
def layer_test(layer_cls,
               kwargs=None,
               input_shape=None,
               input_dtype=None,
               input_data=None,
               expected_output=None,
               expected_output_dtype=None,
               expected_output_shape=None,
               validate_training=True,
               adapt_data=None,
               custom_objects=None,
               test_harness=None,
               supports_masking=None):
  """Test routine for a layer with a single input and single output.

  Args:
    layer_cls: Layer class object.
    kwargs: Optional dictionary of keyword arguments for instantiating the
      layer. It is copied, never mutated.
    input_shape: Input shape tuple.
    input_dtype: Data type of the input data.
    input_data: Numpy array of input data.
    expected_output: Numpy array of the expected output.
    expected_output_dtype: Data type expected for the output.
    expected_output_shape: Shape tuple for the expected shape of the output.
    validate_training: Whether to attempt to validate training on this layer.
      This might be set to False for non-differentiable layers that output
      string or integer values.
    adapt_data: Optional data for an 'adapt' call. If None, adapt() will not
      be tested for this layer. This is only relevant for PreprocessingLayers.
    custom_objects: Optional dictionary mapping name strings to custom objects
      in the layer class. This is helpful for testing custom layers.
    test_harness: The Tensorflow test, if any, that this function is being
      called in.
    supports_masking: Optional boolean to check the `supports_masking` property
      of the layer. If None, the check will not be performed.

  Returns:
    The output data (Numpy array) returned by the layer, for additional
    checks to be done by the calling code.

  Raises:
    ValueError: if both `input_data` and `input_shape` are None.
    AssertionError: if a dtype, shape, masking or output check fails.
  """
  if input_data is None:
    if input_shape is None:
      raise ValueError('input_shape is None')
    if not input_dtype:
      input_dtype = 'float32'
    input_data_shape = list(input_shape)
    # Replace unknown (None) dims with a small random size.
    for i, e in enumerate(input_data_shape):
      if e is None:
        input_data_shape[i] = np.random.randint(1, 4)
    input_data = 10 * np.random.random(input_data_shape)
    if input_dtype[:5] == 'float':
      input_data -= 0.5
    input_data = input_data.astype(input_dtype)
  elif input_shape is None:
    input_shape = input_data.shape
  if input_dtype is None:
    input_dtype = input_data.dtype
  if expected_output_dtype is None:
    expected_output_dtype = input_dtype

  if tf.as_dtype(expected_output_dtype) == tf.string:
    if test_harness:
      assert_equal = test_harness.assertAllEqual
    else:
      assert_equal = string_test
  else:
    if test_harness:
      assert_equal = test_harness.assertAllClose
    else:
      assert_equal = numeric_test

  # instantiation
  # Copy so that adding 'weights' below never mutates the caller's dict.
  kwargs = dict(kwargs or {})
  layer = layer_cls(**kwargs)

  if (supports_masking is not None and
      layer.supports_masking != supports_masking):
    raise AssertionError(
        'When testing layer %s, the `supports_masking` property is %r '
        'but expected to be %r.\nFull kwargs: %s' %
        (layer_cls.__name__, layer.supports_masking, supports_masking, kwargs))

  # Test adapt, if data was passed.
  if adapt_data is not None:
    layer.adapt(adapt_data)

  # test get_weights, set_weights at layer level
  weights = layer.get_weights()
  layer.set_weights(weights)

  # test and instantiation from weights
  # `getargspec` returns an ArgSpec named tuple; the parameter names live in
  # its `.args` list (membership in the tuple itself compares whole fields).
  if 'weights' in tf_inspect.getargspec(layer_cls.__init__).args:
    kwargs['weights'] = weights
    layer = layer_cls(**kwargs)

  # test in functional API
  x = layers.Input(shape=input_shape[1:], dtype=input_dtype)
  y = layer(x)
  if backend.dtype(y) != expected_output_dtype:
    raise AssertionError('When testing layer %s, for input %s, found output '
                         'dtype=%s but expected to find %s.\nFull kwargs: %s' %
                         (layer_cls.__name__, x, backend.dtype(y),
                          expected_output_dtype, kwargs))

  def assert_shapes_equal(expected, actual):
    """Asserts that the output shape from the layer matches the actual shape."""
    if len(expected) != len(actual):
      raise AssertionError(
          'When testing layer %s, for input %s, found output_shape='
          '%s but expected to find %s.\nFull kwargs: %s' %
          (layer_cls.__name__, x, actual, expected, kwargs))

    for expected_dim, actual_dim in zip(expected, actual):
      if isinstance(expected_dim, tf.compat.v1.Dimension):
        expected_dim = expected_dim.value
      if isinstance(actual_dim, tf.compat.v1.Dimension):
        actual_dim = actual_dim.value
      # A None expected dim matches any actual size.
      if expected_dim is not None and expected_dim != actual_dim:
        raise AssertionError(
            'When testing layer %s, for input %s, found output_shape='
            '%s but expected to find %s.\nFull kwargs: %s' %
            (layer_cls.__name__, x, actual, expected, kwargs))

  if expected_output_shape is not None:
    assert_shapes_equal(tf.TensorShape(expected_output_shape), y.shape)

  # check shape inference
  model = models.Model(x, y)
  computed_output_shape = tuple(
      layer.compute_output_shape(
          tf.TensorShape(input_shape)).as_list())
  computed_output_signature = layer.compute_output_signature(
      tf.TensorSpec(shape=input_shape, dtype=input_dtype))
  actual_output = model.predict(input_data)
  actual_output_shape = actual_output.shape
  assert_shapes_equal(computed_output_shape, actual_output_shape)
  assert_shapes_equal(computed_output_signature.shape, actual_output_shape)
  if computed_output_signature.dtype != actual_output.dtype:
    raise AssertionError(
        'When testing layer %s, for input %s, found output_dtype='
        '%s but expected to find %s.\nFull kwargs: %s' %
        (layer_cls.__name__, x, actual_output.dtype,
         computed_output_signature.dtype, kwargs))
  if expected_output is not None:
    assert_equal(actual_output, expected_output)

  # test serialization, weight setting at model level
  model_config = model.get_config()
  recovered_model = models.Model.from_config(model_config, custom_objects)
  if model.weights:
    weights = model.get_weights()
    recovered_model.set_weights(weights)
    output = recovered_model.predict(input_data)
    assert_equal(output, actual_output)

  # test training mode (e.g. useful for dropout tests)
  # Rebuild the model to avoid the graph being reused between predict() and
  # train_on_batch(). See b/120160788 for more details. This should be
  # mitigated after 2.0.
  layer_weights = layer.get_weights()  # Get the layer weights BEFORE training.
  if validate_training:
    model = models.Model(x, layer(x))
    if _thread_local_data.run_eagerly is not None:
      model.compile(
          'rmsprop',
          'mse',
          weighted_metrics=['acc'],
          run_eagerly=should_run_eagerly())
    else:
      model.compile('rmsprop', 'mse', weighted_metrics=['acc'])
    model.train_on_batch(input_data, actual_output)

  # test as first layer in Sequential API
  layer_config = layer.get_config()
  layer_config['batch_input_shape'] = input_shape
  layer = layer.__class__.from_config(layer_config)

  # Test adapt, if data was passed.
  if adapt_data is not None:
    layer.adapt(adapt_data)

  model = models.Sequential()
  model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
  model.add(layer)

  # Restore the pre-training weights so outputs are comparable.
  layer.set_weights(layer_weights)
  actual_output = model.predict(input_data)
  actual_output_shape = actual_output.shape
  for expected_dim, actual_dim in zip(computed_output_shape,
                                      actual_output_shape):
    if expected_dim is not None:
      if expected_dim != actual_dim:
        raise AssertionError(
            'When testing layer %s **after deserialization**, '
            'for input %s, found output_shape='
            '%s but expected to find inferred shape %s.\nFull kwargs: %s' %
            (layer_cls.__name__, x, actual_output_shape,
             computed_output_shape, kwargs))
  if expected_output is not None:
    assert_equal(actual_output, expected_output)

  # test serialization, weight setting at model level
  model_config = model.get_config()
  recovered_model = models.Sequential.from_config(model_config, custom_objects)
  if model.weights:
    weights = model.get_weights()
    recovered_model.set_weights(weights)
    output = recovered_model.predict(input_data)
    assert_equal(output, actual_output)

  # for further checks in the caller function
  return actual_output
def compute_output_shape(self, input_shape):
  """Output shape `[batch, n]`, with `n` the leading dim of `_target_pdf`."""
  batch_dim = input_shape[0]
  num_targets = self._target_pdf.shape[0]
  return tf.TensorShape([batch_dim, num_targets])
def _event_shape(self):
  """Square static event shape `[n, n]`, `n` = domain dim of `self._scale`."""
  n = self._scale.domain_dimension
  return tf.TensorShape((n, n))