def build_subnetwork(self,
                     features,
                     logits_dimension,
                     training,
                     iteration_step,
                     summary,
                     previous_ensemble=None):
    """Validate the build context, then return a small dense subnetwork.

    The first half of the method is a set of assertions about the
    environment in which it is invoked (empty trainable-variables
    collection, global-step getters resolving to the iteration step, and
    summary functions returning fake values). The second half builds a
    dense logits layer on top of a dummy tensor.

    Args:
        features: Only checked for being non-None.
        logits_dimension: Number of units of the dense logits layer. When
            `self._multi_head` is truthy, each of the two heads gets
            `logits_dimension // 2` units.
        training: Only checked for being non-None.
        iteration_step: Only checked for being non-None.
        summary: Only checked for being non-None.
        previous_ensemble: Not used by this method.

    Returns:
        A `Subnetwork` with `complexity=2` and empty `persisted_tensors`.
        `logits` (and `last_layer`) are dicts keyed by "head1"/"head2" in
        the multi-head case, single tensors otherwise. `last_layer` is the
        logits themselves when `self._use_logits_last_layer` is truthy.
    """
    assert features is not None
    assert training is not None
    assert iteration_step is not None
    assert summary is not None

    # Trainable variables collection should always be empty when
    # build_subnetwork is called.
    assert not tf_compat.v1.get_collection(
        tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)

    # Subnetworks get iteration steps instead of global steps: every way of
    # fetching the global step must resolve to the iteration step tensor.
    step_name = "subnetwork_test/iteration_step"
    assert step_name == tf_compat.tensor_name(
        tf_compat.v1.train.get_global_step())
    assert step_name == tf_compat.tensor_name(train.get_global_step())
    assert step_name == tf_compat.tensor_name(
        training_util.get_global_step())
    assert step_name == tf_compat.tensor_name(
        tf_v1.train.get_global_step())
    assert step_name == tf_compat.tensor_name(
        tf_compat.v1.train.get_or_create_global_step())
    assert step_name == tf_compat.tensor_name(
        train.get_or_create_global_step())
    assert step_name == tf_compat.tensor_name(
        training_util.get_or_create_global_step())
    assert step_name == tf_compat.tensor_name(
        tf_v1.train.get_or_create_global_step())

    # Subnetworks get scoped summaries.
    assert "fake_scalar" == tf_compat.v1.summary.scalar("scalar", 1.)
    assert "fake_image" == tf_compat.v1.summary.image("image", 1.)
    assert "fake_histogram" == tf_compat.v1.summary.histogram(
        "histogram", 1.)
    assert "fake_audio" == tf_compat.v1.summary.audio("audio", 1., 1.)

    # Kept under its own name so the closure below never observes the
    # dict that `last_layer` is rebound to in the multi-head branch.
    base_layer = tu.dummy_tensor(shape=(2, 3))

    def logits_fn(logits_dim):
        return tf_compat.v1.layers.dense(
            base_layer,
            units=logits_dim,
            kernel_initializer=tf_compat.v1.glorot_uniform_initializer(
                seed=self._seed))

    if self._multi_head:
        # Floor division keeps `units` an int; true division would hand
        # the layer a float under Python 3.
        head_dimension = logits_dimension // 2
        logits = {
            "head1": logits_fn(head_dimension),
            "head2": logits_fn(head_dimension),
        }
        last_layer = {"head1": base_layer, "head2": base_layer}
    else:
        logits = logits_fn(logits_dimension)
        last_layer = base_layer

    return Subnetwork(
        last_layer=logits if self._use_logits_last_layer else last_layer,
        logits=logits,
        complexity=2,
        persisted_tensors={})
def _load_variable_from_model_dir(self, var):
    """Load `var` by its tensor name from `self._model_dir`.

    Args:
        var: Object whose name is resolved with `tf_compat.tensor_name`.

    Returns:
        Whatever `tf.train.load_variable` returns for that name.
    """
    variable_name = tf_compat.tensor_name(var)
    return tf.train.load_variable(self._model_dir, variable_name)