def test_should_error_out_for_not_recognized_args(self):
  """Verifies the builder rejects a metric_fn with unsupported arguments."""
  binary_head = tf.contrib.estimator.binary_classification_head(
      loss_reduction=tf.losses.Reduction.SUM)

  def metric_fn(features, not_recognized):
    # Touch both arguments so linters don't flag them; the builder should
    # reject "not_recognized" before this body ever runs.
    _, _ = features, not_recognized
    return {}

  # Construction itself must raise, naming the offending argument.
  with self.assertRaisesRegexp(ValueError, "not_recognized"):
    _EnsembleBuilder(binary_head, MixtureWeightType.SCALAR,
                     metric_fn=metric_fn)
def test_init_error(self):
  """Warm-starting mixture weights without a checkpoint_dir must fail."""
  binary_head = tf.contrib.estimator.binary_classification_head(
      loss_reduction=tf.losses.Reduction.SUM)
  # warm_start_mixture_weights=True requires a checkpoint_dir to restore
  # from; passing None is a configuration error.
  with self.assertRaises(ValueError):
    _EnsembleBuilder(
        head=binary_head,
        mixture_weight_type=MixtureWeightType.MATRIX,
        mixture_weight_initializer=tf.zeros_initializer(),
        warm_start_mixture_weights=True,
        checkpoint_dir=None,
        adanet_lambda=0.,
        adanet_beta=0.,
        use_bias=True)
def _make_metrics(sess, metric_fn):
  """Builds an EVAL-mode ensemble spec and returns its metric values.

  Appends a trivial subnetwork (no-op train ops) to an empty ensemble,
  initializes all variables, runs the spec's eval metric ops once, and
  returns a dict mapping each metric name to its value tensor's result.
  """
  binary_head = tf.contrib.estimator.binary_classification_head(
      loss_reduction=tf.losses.Reduction.SUM)
  ensemble_builder = _EnsembleBuilder(
      binary_head, MixtureWeightType.SCALAR, metric_fn=metric_fn)
  spec = ensemble_builder.append_new_subnetwork(
      ensemble_name="test",
      ensemble_spec=None,
      subnetwork_builder=_Builder(
          lambda unused0, unused1: tf.no_op(),
          lambda unused0, unused1: tf.no_op(),
          use_logits_last_layer=True),
      iteration_number=0,
      iteration_step=1,
      summary=_FakeSummary(),
      features={"x": tf.constant([[1.], [2.]])},
      mode=tf.estimator.ModeKeys.EVAL,
      labels=tf.constant([0, 1]))
  # Local variables back the streaming metrics; both initializers are needed.
  sess.run((tf.global_variables_initializer(),
            tf.local_variables_initializer()))
  metric_values = sess.run(spec.eval_metric_ops)
  # Each eval metric op is a (value, update_op) pair; keep the values.
  return {name: pair[1] for name, pair in metric_values.items()}
def _make_metrics(sess, metric_fn):
  """Evaluates a hand-built fake ensemble and returns its metric values.

  Constructs an ensemble spec from a single fake weighted subnetwork in
  EVAL mode, initializes variables, and returns a dict mapping each metric
  name to the second element of its (value, update_op) pair.
  """
  binary_head = tf.contrib.estimator.binary_classification_head(
      loss_reduction=tf.losses.Reduction.SUM)
  ensemble_builder = _EnsembleBuilder(
      binary_head, MixtureWeightType.SCALAR, metric_fn=metric_fn)
  fake_subnetwork = Subnetwork(
      logits=[[1.], [2.]],
      last_layer=[1.],
      complexity=1.,
      persisted_tensors={})
  fake_weighted = WeightedSubnetwork(
      name=tf.constant("fake_weighted"),
      logits=[[1.], [2.]],
      weight=[1.],
      subnetwork=fake_subnetwork)
  spec = ensemble_builder.build_ensemble_spec(
      "fake_ensemble",
      [fake_weighted],
      summary=_FakeSummary(),
      bias=0.,
      features={"x": tf.constant([[1.], [2.]])},
      mode=tf.estimator.ModeKeys.EVAL,
      labels=tf.constant([0, 1]),
      iteration_step=1.)
  # Streaming metrics use local variables; initialize both collections.
  sess.run((tf.global_variables_initializer(),
            tf.local_variables_initializer()))
  metric_values = sess.run(spec.eval_metric_ops)
  return {name: pair[1] for name, pair in metric_values.items()}
def test_append_new_subnetwork(
    self,
    want_logits,
    want_complexity_regularization,
    want_loss=None,
    want_adanet_loss=None,
    want_mixture_weight_vars=None,
    adanet_lambda=0.,
    adanet_beta=0.,
    ensemble_spec_fn=lambda: None,
    use_bias=False,
    use_logits_last_layer=False,
    mixture_weight_type=MixtureWeightType.MATRIX,
    mixture_weight_initializer=tf.zeros_initializer(),
    warm_start_mixture_weights=True,
    subnetwork_builder_class=_Builder,
    mode=tf.estimator.ModeKeys.TRAIN):
  """Parameterized test for `_EnsembleBuilder.append_new_subnetwork`.

  Appends a subnetwork to an (optionally pre-existing) ensemble and checks
  the resulting spec's logits, losses, complexity regularization, bias, and
  train op against the expected `want_*` values supplied by each test case.
  In PREDICT mode only logits and the absence of training artifacts are
  checked.
  """
  seed = 64
  builder = _EnsembleBuilder(
      head=tf.contrib.estimator.binary_classification_head(
          loss_reduction=tf.losses.Reduction.SUM),
      mixture_weight_type=mixture_weight_type,
      mixture_weight_initializer=mixture_weight_initializer,
      warm_start_mixture_weights=warm_start_mixture_weights,
      checkpoint_dir=self.test_subdirectory,
      adanet_lambda=adanet_lambda,
      adanet_beta=adanet_beta,
      use_bias=use_bias)
  features = {"x": tf.constant([[1.], [2.]])}
  labels = tf.constant([0, 1])

  def _subnetwork_train_op_fn(loss, var_list):
    # The subnetwork's train op should only see its own two variables,
    # and they should be exactly the trainable variables in scope.
    self.assertEqual(2, len(var_list))
    self.assertEqual(var_list,
                     tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
    # Subnetworks get iteration steps instead of global steps.
    self.assertEqual("ensemble_test/iteration_step",
                     tf.train.get_global_step().op.name)
    # Subnetworks get scoped summaries.
    self.assertEqual("fake_scalar", tf.summary.scalar("scalar", 1.))
    self.assertEqual("fake_image", tf.summary.image("image", 1.))
    self.assertEqual("fake_histogram", tf.summary.histogram("histogram", 1.))
    self.assertEqual("fake_audio", tf.summary.audio("audio", 1., 1.))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=.1)
    return optimizer.minimize(loss, var_list=var_list)

  def _mixture_weights_train_op_fn(loss, var_list):
    # Mixture weights have a case-dependent variable count.
    self.assertEqual(want_mixture_weight_vars, len(var_list))
    self.assertEqual(var_list,
                     tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
    # Subnetworks get iteration steps instead of global steps.
    self.assertEqual("ensemble_test/iteration_step",
                     tf.train.get_global_step().op.name)
    # Subnetworks get scoped summaries.
    self.assertEqual("fake_scalar", tf.summary.scalar("scalar", 1.))
    self.assertEqual("fake_image", tf.summary.image("image", 1.))
    self.assertEqual("fake_histogram", tf.summary.histogram("histogram", 1.))
    self.assertEqual("fake_audio", tf.summary.audio("audio", 1., 1.))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=.1)
    return optimizer.minimize(loss, var_list=var_list)

  ensemble_spec = builder.append_new_subnetwork(
      # Note: when ensemble_spec is not None and warm_start_mixture_weights
      # is True, we need to make sure that the bias and mixture weights are
      # already saved to the checkpoint_dir.
      ensemble_name="test",
      ensemble_spec=ensemble_spec_fn(),
      subnetwork_builder=subnetwork_builder_class(
          _subnetwork_train_op_fn, _mixture_weights_train_op_fn,
          use_logits_last_layer, seed),
      summary=_FakeSummary(),
      features=features,
      iteration_number=1,
      iteration_step=tf.train.get_or_create_global_step(),
      labels=labels,
      mode=mode)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())

    # Get the real global step outside a subnetwork's context.
    self.assertEqual("global_step", tf.train.get_global_step().op.name)

    # Get global tf.summary outside a subnetwork's context.
    self.assertNotEqual("fake_scalar", tf.summary.scalar("scalar", 1.))
    self.assertNotEqual("fake_image", tf.summary.image("image", 1.))
    self.assertNotEqual("fake_histogram",
                        tf.summary.histogram("histogram", 1.))
    self.assertNotEqual("fake_audio", tf.summary.audio("audio", 1., 1.))

    if mode == tf.estimator.ModeKeys.PREDICT:
      self.assertAllClose(
          want_logits, sess.run(ensemble_spec.ensemble.logits), atol=1e-3)
      # No losses or train op should be built in PREDICT mode, but export
      # outputs must exist for serving.
      self.assertIsNone(ensemble_spec.loss)
      self.assertIsNone(ensemble_spec.adanet_loss)
      self.assertIsNone(ensemble_spec.train_op)
      self.assertIsNotNone(ensemble_spec.export_outputs)
      return

    # Verify that train_op works, previous loss should be greater than loss
    # after a train op.
    loss = sess.run(ensemble_spec.loss)
    for _ in range(3):
      sess.run(ensemble_spec.train_op)
    self.assertGreater(loss, sess.run(ensemble_spec.loss))

    self.assertAllClose(
        want_logits, sess.run(ensemble_spec.ensemble.logits), atol=1e-3)

    # Bias should learn a non-zero value when used.
    if use_bias:
      self.assertNotEqual(0., sess.run(ensemble_spec.ensemble.bias))
    else:
      self.assertAlmostEqual(0., sess.run(ensemble_spec.ensemble.bias))

    self.assertAlmostEqual(
        want_complexity_regularization,
        sess.run(ensemble_spec.complexity_regularization),
        places=3)
    self.assertAlmostEqual(want_loss, sess.run(ensemble_spec.loss), places=3)
    self.assertAlmostEqual(
        want_adanet_loss, sess.run(ensemble_spec.adanet_loss), places=3)
def __init__(self,
             head,
             subnetwork_generator,
             max_iteration_steps,
             mixture_weight_type=MixtureWeightType.SCALAR,
             mixture_weight_initializer=None,
             warm_start_mixture_weights=False,
             adanet_lambda=0.,
             adanet_beta=0.,
             evaluator=None,
             report_materializer=None,
             use_bias=False,
             replicate_ensemble_in_training=False,
             adanet_loss_decay=.9,
             worker_wait_timeout_secs=7200,
             model_dir=None,
             report_dir=None,
             config=None):
  """Initializes an `Estimator`.

  Regarding the options for `mixture_weight_type`:

  A `SCALAR` mixture weight is a rank 0 tensor. It performs an element-
  wise multiplication with its subnetwork's logits. This mixture weight
  is the simplest to learn, the quickest to train, and most likely to
  generalize well.

  A `VECTOR` mixture weight is a tensor of shape [k] where k is the
  ensemble's logits dimension as defined by `head`. It is similar to
  `SCALAR` in that it performs an element-wise multiplication with its
  subnetwork's logits, but is more flexible in learning a subnetworks's
  preferences per class.

  A `MATRIX` mixture weight is a tensor of shape [a, b] where a is the
  number of outputs from the subnetwork's `last_layer` and b is the
  number of outputs from the ensemble's `logits`. This weight
  matrix-multiplies the subnetwork's `last_layer`. This mixture weight
  offers the most flexibility and expressivity, allowing subnetworks to
  have outputs of different dimensionalities. However, it also has the
  most trainable parameters (a*b), and is therefore the most sensitive to
  learning rates and regularization.

  Args:
    head: A `tf.contrib.estimator.Head` instance for computing loss and
      evaluation metrics for every candidate.
    subnetwork_generator: The `adanet.subnetwork.Generator` which defines
      the candidate subnetworks to train and evaluate at every AdaNet
      iteration.
    max_iteration_steps: Total number of steps for which to train
      candidates per iteration. If `OutOfRange` or `StopIteration` occurs
      in the middle, training stops before `max_iteration_steps` steps.
    mixture_weight_type: The `adanet.MixtureWeightType` defining which
      mixture weight type to learn in the linear combination of
      subnetwork outputs.
    mixture_weight_initializer: The initializer for mixture_weights. When
      `None`, the default is different according to
      `mixture_weight_type`. `SCALAR` initializes to 1/N where N is the
      number of subnetworks in the ensemble giving a uniform average.
      `VECTOR` initializes each entry to 1/N where N is the number of
      subnetworks in the ensemble giving a uniform average. `MATRIX` uses
      `tf.zeros_initializer`.
    warm_start_mixture_weights: Whether, at the beginning of an
      iteration, to initialize the mixture weights of the subnetworks
      from the previous ensemble to their learned value at the previous
      iteration, as opposed to retraining them from scratch. Takes
      precedence over the value for `mixture_weight_initializer` for
      subnetworks from previous iterations.
    adanet_lambda: Float multiplier 'lambda' for applying L1
      regularization to subnetworks' mixture weights 'w' in the ensemble
      proportional to their complexity. See Equation (4) in the AdaNet
      paper.
    adanet_beta: Float L1 regularization multiplier 'beta' to apply
      equally to all subnetworks' weights 'w' in the ensemble regardless
      of their complexity. See Equation (4) in the AdaNet paper.
    evaluator: An `Evaluator` for comparing `Ensemble` instances in
      evaluation mode using the training set, or a holdout set. When
      `None`, they are compared using a moving average of their
      `Ensemble`'s AdaNet loss during training.
    report_materializer: A `ReportMaterializer` for materializing a
      `Builder`'s `subnetwork.Reports` into
      `subnetwork.MaterializedReport`s. These reports are made available
      to the Generator at the next iteration, so that it can adapt its
      search space. When `None`, the Generators'
      `generate_candidates` method will receive empty Lists for their
      `previous_ensemble_reports` and `all_reports` arguments.
    use_bias: Whether to add a bias term to the ensemble's logits. Adding
      a bias allows the ensemble to learn a shift in the data, often
      leading to more stable training and better predictions.
    replicate_ensemble_in_training: Whether to freeze a copy of the
      ensembled subnetworks' subgraphs in training mode in addition to
      prediction mode. A copy of the subnetworks' subgraphs is always
      saved in prediction mode so that at prediction time, the ensemble
      and composing subnetworks are all in prediction mode. This argument
      only affects the outputs of the frozen subnetworks in the ensemble.
      When `False` and during candidate training, the frozen subnetworks
      in the ensemble are in prediction mode, so training-only ops like
      dropout are not applied to them. When `True` and training the
      candidates, the frozen subnetworks will be in training mode as
      well, so they will apply training-only ops like dropout. However
      when `True`, this doubles the amount of disk space required to
      store the frozen ensembles, and increases the preparation stage
      between boosting iterations. This argument is useful for
      regularizing learning mixture weights, or for making training-only
      side inputs available in subsequent iterations. For most use-cases,
      this should be `False`.
    adanet_loss_decay: Float decay for the exponential-moving-average of
      the AdaNet objective throughout training. This moving average is a
      data- driven way tracking the best candidate with only the training
      set.
    worker_wait_timeout_secs: Float number of seconds for workers to wait
      for chief to prepare the next iteration during distributed
      training. This is needed to prevent workers waiting indefinitely
      for a chief that may have crashed or been turned down. When the
      timeout is exceeded, the worker exits the train loop. In situations
      where the chief job is much slower than the worker jobs, this
      timeout should be increased.
    model_dir: Directory to save model parameters, graph and etc. This
      can also be used to load checkpoints from the directory into a
      estimator to continue training a previously saved model.
    report_dir: Directory where the
      `adanet.subnetwork.MaterializedReport`s materialized by
      `report_materializer` would be saved. If `report_materializer` is
      None, this will not save anything. If `None` or empty string,
      defaults to "<model_dir>/report".
    config: `RunConfig` object to configure the runtime settings.

  Returns:
    An `Estimator` instance.

  Raises:
    ValueError: If `subnetwork_generator` is `None`.
    ValueError: If `max_iteration_steps` is <= 0.
  """
  # TODO: Add argument to specify how many frozen graph
  # checkpoints to keep.

  # Validate required arguments before touching any state.
  if subnetwork_generator is None:
    raise ValueError("subnetwork_generator can't be None.")
  if max_iteration_steps <= 0.:
    raise ValueError("max_iteration_steps must be > 0.")

  self._adanet_loss_decay = adanet_loss_decay

  # Overwrite superclass's assert that members are not overwritten in order
  # to overwrite public methods. Note that we are doing something that is not
  # explicitly supported by the Estimator API and may break in the future.
  tf.estimator.Estimator._assert_members_are_not_overridden = staticmethod(
      lambda _: None)

  # Builders that construct the per-iteration candidates and ensembles.
  self._ensemble_builder = _EnsembleBuilder(
      head=head,
      mixture_weight_type=mixture_weight_type,
      mixture_weight_initializer=mixture_weight_initializer,
      warm_start_mixture_weights=warm_start_mixture_weights,
      adanet_lambda=adanet_lambda,
      adanet_beta=adanet_beta,
      use_bias=use_bias)
  candidate_builder = _CandidateBuilder(
      max_steps=max_iteration_steps,
      adanet_loss_decay=self._adanet_loss_decay)
  self._iteration_builder = _IterationBuilder(candidate_builder,
                                              self._ensemble_builder)
  self._freezer = _EnsembleFreezer()
  self._evaluation_checkpoint_path = None
  self._evaluator = evaluator
  self._report_materializer = report_materializer
  self._replicate_ensemble_in_training = replicate_ensemble_in_training
  self._worker_wait_timeout_secs = worker_wait_timeout_secs
  self._evaluation_name = None
  self._inside_adanet_training_loop = False

  # This `Estimator` is responsible for bookkeeping across iterations, and
  # for training the subnetworks in both a local and distributed setting.
  # Subclassing improves future-proofing against new private methods being
  # added to `tf.estimator.Estimator` that are expected to be callable by
  # external functions, such as in b/110435640.
  super(Estimator, self).__init__(
      model_fn=self._model_fn,
      params={
          self._Keys.SUBNETWORK_GENERATOR: subnetwork_generator,
      },
      config=config,
      model_dir=model_dir)

  # This is defined after base Estimator's init so that report_accessor can
  # use the same temporary model_dir as the underlying Estimator even if
  # model_dir is not provided.
  report_dir = report_dir or os.path.join(self._model_dir, "report")
  self._report_accessor = _ReportAccessor(report_dir)
def test_append_new_subnetwork(
    self,
    want_logits,
    want_complexity_regularization,
    want_loss=None,
    want_adanet_loss=None,
    want_mixture_weight_vars=None,
    adanet_lambda=0.,
    adanet_beta=0.,
    ensemble_spec_fn=lambda: None,
    use_bias=False,
    use_logits_last_layer=False,
    mixture_weight_type=MixtureWeightType.MATRIX,
    mixture_weight_initializer=tf.zeros_initializer(),
    warm_start_mixture_weights=True,
    subnetwork_builder_class=_Builder,
    mode=tf.estimator.ModeKeys.TRAIN):
  """Parameterized test for `_EnsembleBuilder.append_new_subnetwork`.

  Builds an ensemble spec for the given mixture-weight configuration and
  verifies its logits, losses, bias, and complexity regularization against
  the `want_*` expectations from each parameterized case. In PREDICT mode
  only logits and the absence of training artifacts are checked.
  """
  seed = 64
  builder = _EnsembleBuilder(
      head=tf.contrib.estimator.binary_classification_head(
          loss_reduction=tf.losses.Reduction.SUM),
      mixture_weight_type=mixture_weight_type,
      mixture_weight_initializer=mixture_weight_initializer,
      warm_start_mixture_weights=warm_start_mixture_weights,
      adanet_lambda=adanet_lambda,
      adanet_beta=adanet_beta,
      use_bias=use_bias)
  features = {"x": tf.constant([[1.], [2.]])}
  labels = tf.constant([0, 1])

  def _subnetwork_train_op_fn(loss, var_list):
    # The subnetwork's train op should only see its own two variables.
    self.assertEqual(2, len(var_list))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=.1)
    return optimizer.minimize(loss, var_list=var_list)

  def _mixture_weights_train_op_fn(loss, var_list):
    # Mixture weights have a case-dependent variable count.
    self.assertEqual(want_mixture_weight_vars, len(var_list))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=.1)
    return optimizer.minimize(loss, var_list=var_list)

  ensemble_spec = builder.append_new_subnetwork(
      ensemble_spec=ensemble_spec_fn(),
      subnetwork_builder=subnetwork_builder_class(
          _subnetwork_train_op_fn, _mixture_weights_train_op_fn,
          use_logits_last_layer, seed),
      summary=tf.summary,
      features=features,
      iteration_step=tf.train.get_or_create_global_step(),
      labels=labels,
      mode=mode)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())

    if mode == tf.estimator.ModeKeys.PREDICT:
      self.assertAllClose(want_logits,
                          sess.run(ensemble_spec.ensemble.logits),
                          atol=1e-3)
      # No losses or train op should be built in PREDICT mode, but export
      # outputs must exist for serving.
      self.assertIsNone(ensemble_spec.loss)
      self.assertIsNone(ensemble_spec.adanet_loss)
      self.assertIsNone(ensemble_spec.train_op)
      self.assertIsNotNone(ensemble_spec.export_outputs)
      return

    # Verify that train_op works, previous loss should be greater than loss
    # after a train op.
    loss = sess.run(ensemble_spec.loss)
    for _ in range(3):
      sess.run(ensemble_spec.train_op)
    self.assertGreater(loss, sess.run(ensemble_spec.loss))

    self.assertAllClose(want_logits,
                        sess.run(ensemble_spec.ensemble.logits),
                        atol=1e-3)

    # Bias should learn a non-zero value when used.
    if use_bias:
      self.assertNotEqual(0., sess.run(ensemble_spec.ensemble.bias))
    else:
      self.assertAlmostEqual(0., sess.run(ensemble_spec.ensemble.bias))

    self.assertAlmostEqual(
        want_complexity_regularization,
        sess.run(ensemble_spec.complexity_regularization),
        places=3)
    self.assertAlmostEqual(want_loss, sess.run(ensemble_spec.loss), places=3)
    self.assertAlmostEqual(want_adanet_loss,
                           sess.run(ensemble_spec.adanet_loss),
                           places=3)