def _build_subnetwork(self, multi_head=False):
  """Builds a small test `subnetwork.Subnetwork` with one-unit dense logits.

  Args:
    multi_head: Falsy for a single-head subnetwork. Otherwise an iterable of
      head names; one dense logits layer is built per name.

  Returns:
    A `subnetwork.Subnetwork` with `complexity=2`. Its `logits` (and
    `last_layer`) is a single tensor in the single-head case, or a dict
    mapping each head name to its own logits tensor in the multi-head case.
  """
  # Non-trainable random input of shape (2, 3). It is read once, so every
  # head's dense layer consumes the same tensor.
  features = tf.Variable(
      tf_compat.random_normal(shape=(2, 3)), trainable=False).read_value()

  def new_logits():
    # Called once per head; no `name`/`reuse` is passed to the layer.
    return tf_compat.v1.layers.dense(
        features,
        units=1,
        kernel_initializer=tf_compat.v1.glorot_uniform_initializer())

  if multi_head:
    logits = {head: new_logits() for head in multi_head}
  else:
    logits = new_logits()

  # NOTE(review): `last_layer` is set to the logits, not to the (2, 3)
  # features. An earlier version also built a per-head dict of the features
  # that was never used (dead code, removed here). Confirm which value the
  # callers expect before changing this argument.
  return subnetwork.Subnetwork(last_layer=logits, logits=logits, complexity=2)
def dummy_tensor(shape=(), random_seed=42):
  """Returns the value of a freshly created, non-trainable random variable.

  Args:
    shape: Shape of the tensor to create; the default `()` yields a scalar.
    random_seed: Seed forwarded to `tf_compat.random_normal`.

  Returns:
    The tensor produced by `read_value()` on a non-trainable `tf.Variable`
    whose initial value comes from `tf_compat.random_normal`.
  """
  initial_value = tf_compat.random_normal(shape=shape, seed=random_seed)
  holder = tf.Variable(initial_value, trainable=False)
  return holder.read_value()