コード例 #1
0
ファイル: builder_test.py プロジェクト: jackd/more-keras
 def test_batched_tensor(self):
     """Check that `builder.batched` on a (4, 5) float32 prebatch input
     yields static shape [None, 4, 5], then build the preprocessor."""
     meta_builder = builder.MetaNetworkBuilder()
     with meta_builder:
         unbatched = meta_builder.prebatch_input(shape=(4, 5),
                                                dtype=tf.float32)
         batched_shape = builder.batched(unbatched).shape.as_list()
         self.assertEqual(batched_shape, [None, 4, 5])
     # The preprocessor is requested after the builder context has exited.
     meta_builder.preprocessor(())
コード例 #2
0
ファイル: builder_test.py プロジェクト: jackd/more-keras
 def test_batched_ragged(self):
     """Check batching of a prebatch input whose leading dim is unknown.

     `builder.batched` on a (None, 5) float32 prebatch input is expected to
     yield a `tf.RaggedTensor` with static shape [None, None, 5].
     """
     b = builder.MetaNetworkBuilder()
     with b:
         x = b.prebatch_input(shape=(None, 5), dtype=tf.float32)
         batched = builder.batched(x)
         # assertIsInstance reports the actual type on failure, unlike
         # assertTrue(isinstance(...)) which only reports "False is not true".
         self.assertIsInstance(batched, tf.RaggedTensor)
         self.assertEqual(batched.shape.as_list(), [None, None, 5])
     b.preprocessor(())
コード例 #3
0
ファイル: builder_test.py プロジェクト: jackd/more-keras
 def _test_batched_ragged_ragged(self, row_split_vals, row_splits_size):
     """Build a ragged tensor from two prebatch inputs and register it as a
     batched model input, then build the preprocessor.

     Args:
         row_split_vals: not used by this helper; kept so existing callers
             keep working.
         row_splits_size: static length of the prebatch `row_splits` input.
     """
     # NOTE(review): no assertions here; presumably the calling test passes
     # as long as graph construction raises no error — confirm intent.
     b = builder.MetaNetworkBuilder()
     with b:
         values = b.prebatch_input(shape=(None, 5), dtype=tf.float32)
         row_splits = b.prebatch_input(shape=(row_splits_size,),
                                       dtype=tf.int64)
         x = tf.RaggedTensor.from_row_splits(values, row_splits)
         b.as_batched_model_input(x)
     b.preprocessor(())
コード例 #4
0
ファイル: trainers.py プロジェクト: jackd/more-keras
    def _rebuild_model(self, input_spec, output_spec):
        """Rebuild the model and its label/weight preprocessor.

        Inside a `MetaNetworkBuilder` context:
        - creates `self._model` via `self._model_fn(input_spec, output_spec)`;
        - calls `self.compile_model()`;
        - declares batched prebatch inputs for the labels and (optional)
          weights described by `self.problem.element_spec`;
        - stores the builder's preprocessor in `self._preprocessor`.

        Args:
            input_spec: forwarded to `self._model_fn`.
            output_spec: forwarded to `self._model_fn`.

        Raises:
            ValueError: if `self.problem.element_spec` has neither 2 nor 3
                entries (presumably features, labels[, weights]).
        """
        from more_keras.meta_models import builder as builder_lib
        with builder_lib.MetaNetworkBuilder() as builder:
            self._model = self._model_fn(input_spec, output_spec)
            self.compile_model()

            # build preprocessor
            element_spec = self.problem.element_spec
            if len(element_spec) == 2:
                labels_spec = element_spec[1]
                weights_spec = None
            elif len(element_spec) == 3:
                labels_spec, weights_spec = element_spec[1:]
            else:
                raise ValueError(
                    'Expected problem.element_spec to have 2 or 3 entries, '
                    'got {}'.format(len(element_spec)))

            def _batched_input(leaf_spec):
                # `weights_spec` may be None (2-entry spec); pass it through.
                if leaf_spec is None:
                    return None
                return builder.batched(
                    builder.prebatch_input(shape=leaf_spec.shape,
                                           dtype=leaf_spec.dtype))

            labels, weights = tf.nest.map_structure(
                _batched_input, (labels_spec, weights_spec))
            self._preprocessor = builder.preprocessor(labels, weights)
コード例 #5
0
ファイル: neigh_test.py プロジェクト: jackd/deep-cloud
 def test_transposed_consistent(self):
     """Run the transposed-consistency check inside a MetaNetworkBuilder."""
     meta_builder = b.MetaNetworkBuilder()
     with meta_builder:
         self._test_transposed_consistent()
コード例 #6
0
ファイル: neigh_test.py プロジェクト: jackd/deep-cloud
 def test_sampled_consistent(self):
     """Run the sampled-consistency check inside a MetaNetworkBuilder."""
     meta_builder = b.MetaNetworkBuilder()
     with meta_builder:
         self._test_sampled_consistent()