    # Method of a fake adanet Builder used in tests. Presumed context from the
    # surrounding test module: `tf_compat` and `Subnetwork` come from the
    # adanet package, and `tu` is adanet's testing utilities.
    def build_subnetwork(self,
                         features,
                         logits_dimension,
                         training,
                         iteration_step,
                         summary,
                         previous_ensemble=None):
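        # Sanity-check the arguments the framework passes to every
        # subnetwork builder.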
        assert features is not None
        assert training is not None
        assert iteration_step is not None
        assert summary is not None

        # Trainable variables collection should always be empty when
        # build_subnetwork is called.
        assert not tf_compat.v1.get_collection(
            tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)

        # Subnetworks get iteration steps instead of global steps.
        step_name = "subnetwork_test/iteration_step"
        assert step_name == tf_compat.tensor_name(
            tf_compat.v1.train.get_global_step())
        assert step_name == tf_compat.tensor_name(train.get_global_step())
        assert step_name == tf_compat.tensor_name(
            training_util.get_global_step())
        assert step_name == tf_compat.tensor_name(
            tf_v1.train.get_global_step())
        assert step_name == tf_compat.tensor_name(
            tf_compat.v1.train.get_or_create_global_step())
        assert step_name == tf_compat.tensor_name(
            train.get_or_create_global_step())
        assert step_name == tf_compat.tensor_name(
            training_util.get_or_create_global_step())
        assert step_name == tf_compat.tensor_name(
            tf_v1.train.get_or_create_global_step())

        # Subnetworks get scoped summaries.
        assert "fake_scalar" == tf_compat.v1.summary.scalar("scalar", 1.)
        assert "fake_image" == tf_compat.v1.summary.image("image", 1.)
        assert "fake_histogram" == tf_compat.v1.summary.histogram(
            "histogram", 1.)
        assert "fake_audio" == tf_compat.v1.summary.audio("audio", 1., 1.)
        last_layer = tu.dummy_tensor(shape=(2, 3))

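        # Builds a dense logits layer on top of `last_layer`, seeded so the
        # weights are deterministic across runs.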
        def logits_fn(logits_dim):
            return tf_compat.v1.layers.dense(
                last_layer,
                units=logits_dim,
                kernel_initializer=tf_compat.v1.glorot_uniform_initializer(
                    seed=self._seed))

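        # With multiple heads, give each head half of the logits dimension
        # and reuse the same last layer for both.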
        if self._multi_head:
            logits = {
                "head1": logits_fn(logits_dimension / 2),
                "head2": logits_fn(logits_dimension / 2)
            }
            last_layer = {"head1": last_layer, "head2": last_layer}
        else:
            logits = logits_fn(logits_dimension)

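        # Package the outputs into a Subnetwork; complexity is a dummy
        # constant in these tests.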
        return Subnetwork(
            last_layer=logits if self._use_logits_last_layer else last_layer,
            logits=logits,
            complexity=2,
            persisted_tensors={})
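
For context, a Builder whose build_subnetwork looks like the one above is normally proposed by an adanet.subnetwork.Generator. A minimal sketch, assuming adanet's public Generator API (this class and its wiring are illustrative, not part of the test):

import adanet

class SingleBuilderGenerator(adanet.subnetwork.Generator):
    """Proposes the same fake builder at every adanet iteration."""

    def __init__(self, builder):
        self._builder = builder

    def generate_candidates(self, previous_ensemble, iteration_number,
                            previous_ensemble_reports, all_reports):
        # adanet calls this once per iteration; returning a single builder
        # means each iteration considers exactly one candidate subnetwork.
        return [self._builder]
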
Example #2
def _load_variable_from_model_dir(self, var):
    return tf.train.load_variable(self._model_dir,
                                  tf_compat.tensor_name(var))
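
For context, tf.train.load_variable reads a single variable's value out of a checkpoint as a NumPy array, without building a graph or session. A minimal sketch, with a hypothetical checkpoint directory and variable name:

import tensorflow as tf

# Both the directory and the variable name are hypothetical placeholders.
kernel = tf.train.load_variable("/tmp/model_dir", "dense/kernel")
print(kernel.shape)  # Plain NumPy array.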