Ejemplo n.º 1
0
  def setUp(self):
    """Creates the shared optimizer/ensembler and stubs `load_variable`.

    `load_variable` is replaced by a mock in `tf.train`,
    `tf.compat.v1.train` and `tf.compat.v2.train`. Only the `tf.train` mock
    is given a side effect: a fake that checks the checkpoint directory and
    returns an initialized variable holding 1.0.
    """
    super(ComplexityRegularizedEnsemblerTest, self).setUp()

    self._optimizer = tf_compat.v1.train.GradientDescentOptimizer(
        learning_rate=.1)
    self.easy_ensembler = ensemble.ComplexityRegularizedEnsembler(
        optimizer=self._optimizer)

    # The patches below are started manually, so nothing undoes them
    # automatically. Register the stop before starting any of them so the
    # mocks cannot outlive this test, even if a later `start()` raises.
    self.addCleanup(mock.patch.stopall)

    mock.patch.object(tf.train, 'load_variable', autospec=False).start()
    mock.patch.object(
        tf.compat.v1.train, 'load_variable', autospec=False).start()
    mock.patch.object(
        tf.compat.v2.train, 'load_variable', autospec=False).start()

    def load_variable(checkpoint_dir, name):
      """Fake loader: returns an initialized variable with value 1.0."""
      self.assertEqual(checkpoint_dir, 'fake_checkpoint_dir')
      var = tf_compat.v1.get_variable(
          name='fake_loaded_variable_' + name, initializer=1.)
      with self.test_session() as sess:
        sess.run(var.initializer)
        return var

    # NOTE(review): only the `tf.train` mock is wired to the fake; the v1/v2
    # mocks keep their default MagicMock return value. Confirm this is intended.
    tf.train.load_variable.side_effect = load_variable

    self.summary = _FakeSummary()
Ejemplo n.º 2
0
 def test_build_train_op_no_op(self):
   """Checks the train op built by an ensembler constructed with defaults.

   In eager mode the result must be `None`; otherwise its op type must match
   that of `tf.no_op()`.
   """
   ensembler = ensemble.ComplexityRegularizedEnsembler()
   unused_args = [None] * 7  # arguments unused
   train_op = ensembler.build_train_op(*unused_args)
   if tf.executing_eagerly():
     self.assertIsNone(train_op)
     return
   self.assertEqual(train_op.type, tf.no_op().type)
Ejemplo n.º 3
0
    def setUp(self):
        """Creates the shared optimizer/ensembler and stubs variable loading.

        `load_variable` is mocked in `tf.train`, `tf.compat.v1.train` and
        `tf.compat.v2.train`, and `_load_variable` is mocked on the
        `ComplexityRegularizedEnsembler` class so that it returns 1.0.
        """
        super(ComplexityRegularizedEnsemblerTest, self).setUp()

        self._optimizer = tf_compat.v1.train.GradientDescentOptimizer(
            learning_rate=.1)
        self.easy_ensembler = ensemble.ComplexityRegularizedEnsembler(
            optimizer=self._optimizer)

        # These patches are started manually and one of them replaces an
        # attribute on the ensembler class itself. Register the stop before
        # starting any of them so the mocks cannot leak into other tests.
        self.addCleanup(mock.patch.stopall)

        mock.patch.object(tf.train, 'load_variable', autospec=False).start()
        mock.patch.object(tf.compat.v1.train, 'load_variable',
                          autospec=False).start()
        mock.patch.object(tf.compat.v2.train, 'load_variable',
                          autospec=False).start()
        mock.patch.object(ensemble.ComplexityRegularizedEnsembler,
                          '_load_variable',
                          autospec=False).start()

        def _load_variable(var, previous_iteration_checkpoint):
            """Fake loader: requires a checkpoint and always yields 1.0."""
            del var  # unused
            assert previous_iteration_checkpoint is not None
            return 1.0

        complexity_regularized_ensembler = ensemble.ComplexityRegularizedEnsembler
        complexity_regularized_ensembler._load_variable.side_effect = _load_variable

        self.summary = _FakeSummary()
Ejemplo n.º 4
0
 def test_build_train_op(self):
   """Runs one step of the built train op and checks the updated weight."""
   weight = tf.Variable(0., name='dummy_weight')
   loss = weight * 2.
   optimizer = tf_compat.v1.train.GradientDescentOptimizer(.1)
   ensembler = ensemble.ComplexityRegularizedEnsembler(optimizer=optimizer)
   easy_ensemble = self._build_easy_ensemble([self._build_subnetwork()])
   unused_args = [None] * 4
   train_op = ensembler.build_train_op(
       easy_ensemble, loss, [weight], *unused_args)
   with tf_compat.v1.Session() as sess:
     sess.run(tf_compat.v1.global_variables_initializer())
     sess.run(train_op)
     # d(loss)/d(weight) is 2 and the learning rate is .1, so a single
     # gradient-descent step moves the weight from 0 to -.2.
     updated_weight = sess.run(weight)
     self.assertAllClose(-.2, updated_weight)
Ejemplo n.º 5
0
 def test_build_train_op(self):
   """Runs one step of the built train op in graph mode and checks the weight.

   The session is created with GPU `allow_growth` enabled.
   """
   with context.graph_mode():
     weight = tf.Variable(0., name='dummy_weight')
     loss = weight * 2.
     optimizer = tf_compat.v1.train.GradientDescentOptimizer(.1)
     ensembler = ensemble.ComplexityRegularizedEnsembler(optimizer=optimizer)
     easy_ensemble = self._build_easy_ensemble([self._build_subnetwork()])
     unused_args = [None] * 4
     train_op = ensembler.build_train_op(
         easy_ensemble, loss, [weight], *unused_args)
     gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
     session_config = tf.compat.v1.ConfigProto(gpu_options=gpu_options)
     with tf_compat.v1.Session(config=session_config) as sess:
       sess.run(tf_compat.v1.global_variables_initializer())
       sess.run(train_op)
       # d(loss)/d(weight) is 2 and the learning rate is .1, so a single
       # gradient-descent step moves the weight from 0 to -.2.
       updated_weight = sess.run(weight)
       self.assertAllClose(-.2, updated_weight)
Ejemplo n.º 6
0
    def setUp(self):
        """Creates the shared optimizer/ensembler and stubs their collaborators.

        The optimizer's `minimize` and `tf.contrib.framework.load_variable`
        are both replaced by autospec'd mocks; the latter is wired to a fake
        that checks the checkpoint directory and returns a `tf.Variable`
        whose initial value is 1.0.
        """
        super(ComplexityRegularizedEnsemblerTest, self).setUp()

        self._optimizer = tf.train.GradientDescentOptimizer(learning_rate=.1)
        self.easy_ensembler = ensemble.ComplexityRegularizedEnsembler(
            optimizer=self._optimizer)

        # The patches below are started manually and one is module-wide.
        # Register the stop before starting any of them so the mocks cannot
        # leak into other tests.
        self.addCleanup(mock.patch.stopall)

        mock.patch.object(self._optimizer, 'minimize', autospec=True).start()

        mock.patch.object(tf.contrib.framework, 'load_variable',
                          autospec=True).start()

        def load_variable(checkpoint_dir, name):
            """Fake loader: returns a variable with initial value 1.0."""
            self.assertEqual(checkpoint_dir, 'fake_checkpoint_dir')
            return tf.Variable(initial_value=1.,
                               name='fake_loaded_variable_' + name)

        tf.contrib.framework.load_variable.side_effect = load_variable

        self.summary = _FakeSummary()
Ejemplo n.º 7
0
 def test_build_train_op_no_op(self):
     """Checks that a default-constructed ensembler builds a no-op train op."""
     ensembler = ensemble.ComplexityRegularizedEnsembler()
     unused_args = [None] * 7  # arguments unused
     train_op = ensembler.build_train_op(*unused_args)
     self.assertEqual(train_op.type, tf.no_op().type)
Ejemplo n.º 8
0
    def test_build_ensemble(
            self,
            mixture_weight_type=ensemble.MixtureWeightType.SCALAR,
            mixture_weight_initializer=None,
            warm_start_mixture_weights=False,
            adanet_lambda=0.,
            adanet_beta=0.,
            multi_head=None,
            use_bias=False,
            num_subnetworks=1,
            num_previous_ensemble_subnetworks=0,
            expected_complexity_regularization=0.,
            expected_summary_scalars=None):
        """Builds an ensemble and checks its subnetworks, scalars and penalty.

        Args:
          mixture_weight_type: Forwarded to the ensembler constructor.
          mixture_weight_initializer: Forwarded to the ensembler constructor.
          warm_start_mixture_weights: Forwarded to the ensembler constructor;
            when truthy, `model_dir` is set to 'fake_checkpoint_dir'.
          adanet_lambda: Forwarded to the ensembler constructor.
          adanet_beta: Forwarded to the ensembler constructor.
          multi_head: Forwarded to `self._build_subnetwork`.
          use_bias: Forwarded to the ensembler constructor.
          num_subnetworks: How many of the two new subnetworks are passed to
            `build_ensemble`.
          num_previous_ensemble_subnetworks: How many of the two previous
            iteration's subnetworks are passed to `build_ensemble`.
          expected_complexity_regularization: Value the built ensemble's
            complexity regularization is compared against.
          expected_summary_scalars: Optional mapping from summary key to the
            expected scalar values; skipped when empty or `None`.
        """
        checkpoint_dir = (
            'fake_checkpoint_dir' if warm_start_mixture_weights else None)
        ensembler = ensemble.ComplexityRegularizedEnsembler(
            optimizer=self._optimizer,
            mixture_weight_type=mixture_weight_type,
            mixture_weight_initializer=mixture_weight_initializer,
            warm_start_mixture_weights=warm_start_mixture_weights,
            model_dir=checkpoint_dir,
            adanet_lambda=adanet_lambda,
            adanet_beta=adanet_beta,
            use_bias=use_bias)

        # Previous iteration: two subnetworks and the ensemble built from them.
        with tf.variable_scope('dummy_adanet_scope_iteration_0'):
            all_previous_subnetworks = [
                self._build_subnetwork(multi_head) for _ in range(2)
            ]
            previous_ensemble = self._build_easy_ensemble(
                all_previous_subnetworks)

        # Current iteration: two candidate subnetworks, of which only the
        # requested number take part in the new ensemble.
        with tf.variable_scope('dummy_adanet_scope_iteration_1'):
            candidate_subnetworks = [
                self._build_subnetwork(multi_head) for _ in range(2)
            ]
            subnetworks = candidate_subnetworks[:num_subnetworks]
            previous_ensemble_subnetworks = (
                all_previous_subnetworks[:num_previous_ensemble_subnetworks])

            self.summary.clear_scalars()

            built_ensemble = ensembler.build_ensemble(
                subnetworks=subnetworks,
                previous_ensemble_subnetworks=previous_ensemble_subnetworks,
                features=None,
                labels=None,
                logits_dimension=None,
                training=None,
                iteration_step=None,
                summary=self.summary,
                previous_ensemble=previous_ensemble)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            fetches = [
                self.summary.scalars,
                built_ensemble.complexity_regularization,
            ]
            summary_scalars, complexity_regularization = sess.run(fetches)

            for key, expected in (expected_summary_scalars or {}).items():
                self._assert_list_almost_equal(summary_scalars[key], expected)

            actual_subnetworks = [
                weighted.subnetwork
                for weighted in built_ensemble.weighted_subnetworks
            ]
            self.assertEqual(actual_subnetworks,
                             previous_ensemble_subnetworks + subnetworks)

            self.assertAlmostEqual(complexity_regularization,
                                   expected_complexity_regularization)
            self.assertIsNotNone(sess.run(built_ensemble.logits))
Ejemplo n.º 9
0
 def test_build_train_op_no_op(self):
   """Checks, in graph mode, that a default ensembler builds a no-op."""
   with context.graph_mode():
     ensembler = ensemble.ComplexityRegularizedEnsembler()
     unused_args = [None] * 7  # arguments unused
     train_op = ensembler.build_train_op(*unused_args)
     self.assertEqual(train_op.type, tf.no_op().type)