def test_batched_tensor(self):
    """Batching a fixed-shape (4, 5) prebatch input gives shape [None, 4, 5]."""
    meta_builder = builder.MetaNetworkBuilder()
    with meta_builder:
        prebatch = meta_builder.prebatch_input(shape=(4, 5), dtype=tf.float32)
        out = builder.batched(prebatch)
        expected_shape = [None, 4, 5]
        self.assertEqual(out.shape.as_list(), expected_shape)
    # Build the preprocessor with an empty labels structure.
    meta_builder.preprocessor(())
def test_batched_ragged(self):
    """Batching a prebatch input with an unknown leading dim yields a RaggedTensor.

    The prebatch shape is ``(None, 5)``; the test expects the batched result to
    be a ``tf.RaggedTensor`` whose static shape is ``[None, None, 5]``.
    """
    b = builder.MetaNetworkBuilder()
    with b:
        x = b.prebatch_input(shape=(None, 5), dtype=tf.float32)
        batched = builder.batched(x)
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)), which only reports "False is not true".
        self.assertIsInstance(batched, tf.RaggedTensor)
        self.assertEqual(batched.shape.as_list(), [None, None, 5])
    # Build the preprocessor with an empty labels structure.
    b.preprocessor(())
def _test_batched_ragged_ragged(self, row_split_vals, row_splits_size):
    """Build a RaggedTensor from prebatch inputs and register it as batched model input.

    Args:
        row_split_vals: unused by this helper. NOTE(review): kept only for
            signature compatibility with callers; confirm whether it was meant
            to feed the ``row_splits`` input.
        row_splits_size: static length of the prebatch ``row_splits`` input.
    """
    # The previous private-module import of `ragged_tensor_shape` and the bare
    # `rts.broadcast_to` expression were removed: the attribute access was
    # discarded, so they could only ever fail, never contribute.
    b = builder.MetaNetworkBuilder()
    with b:
        values = b.prebatch_input(shape=(None, 5), dtype=tf.float32)
        row_splits = b.prebatch_input(shape=(row_splits_size,), dtype=tf.int64)
        x = tf.RaggedTensor.from_row_splits(values, row_splits)
        b.as_batched_model_input(x)
    # Build the preprocessor with an empty labels structure.
    b.preprocessor(())
def _rebuild_model(self, input_spec, output_spec):
    """Rebuild and compile the model, then rebuild the matching preprocessor.

    The model is created via ``self._model_fn(input_spec, output_spec)`` inside
    a fresh ``MetaNetworkBuilder`` scope and compiled with
    ``self.compile_model()``. A preprocessor is then built from the label (and
    optional weight) parts of ``self.problem.element_spec`` and stored on
    ``self._preprocessor``.

    Args:
        input_spec: forwarded unchanged to ``self._model_fn``.
        output_spec: forwarded unchanged to ``self._model_fn``.

    Raises:
        ValueError: if ``self.problem.element_spec`` does not have 2 entries
            (features, labels) or 3 entries (features, labels, weights).
    """
    from more_keras.meta_models import builder as b
    with b.MetaNetworkBuilder() as builder:
        self._model = self._model_fn(input_spec, output_spec)
        self.compile_model()

        # build preprocessor
        element_spec = self.problem.element_spec
        if len(element_spec) == 2:
            labels_spec, weights_spec = element_spec[1], None
        elif len(element_spec) == 3:
            labels_spec, weights_spec = element_spec[1:]
        else:
            raise ValueError(
                'Expected problem.element_spec to have 2 (features, labels) or '
                '3 (features, labels, weights) entries, got {}'.format(
                    len(element_spec)))

        def _batched_input(leaf_spec):
            # Missing components (e.g. absent weights) stay None.
            if leaf_spec is None:
                return None
            return builder.batched(
                builder.prebatch_input(shape=leaf_spec.shape,
                                       dtype=leaf_spec.dtype))

        labels, weights = tf.nest.map_structure(_batched_input,
                                                (labels_spec, weights_spec))
        self._preprocessor = builder.preprocessor(labels, weights)
def test_transposed_consistent(self):
    """Run the transposed-consistency check inside a fresh MetaNetworkBuilder scope."""
    scope = b.MetaNetworkBuilder()
    with scope:
        self._test_transposed_consistent()
def test_sampled_consistent(self):
    """Run the sampled-consistency check inside a fresh MetaNetworkBuilder scope."""
    scope = b.MetaNetworkBuilder()
    with scope:
        self._test_sampled_consistent()