def get_channel_embeddings(io_depth, targets, hidden_size, name="channel"):
    """Embed every channel of `targets` and concatenate the embeddings.

    All channels share one embedding table with `256 * io_depth` rows. Channel
    `c` uses rows `[c * 256, (c + 1) * 256)`, which is achieved by adding
    `c * 256` to that channel's ids before the lookup.

    Args:
      io_depth: number of channels. `targets` is split into this many pieces
        along axis 3, and each piece is squeezed on axis 3, so that axis must
        have size `io_depth`.
      targets: tensor of per-channel ids (presumably integers in [0, 256) --
        TODO confirm against callers).
      hidden_size: number of columns of the embedding table; the table is
        also multiplied by `hidden_size ** 0.5`.
      name: suffix appended to the embedding variable's name.

    Returns:
      The per-channel lookup results concatenated along the last axis.
    """
    per_channel_targets = tf.split(targets, io_depth, axis=3)

    table = tf.get_variable(
        "rgb_target_emb_%s" % name, [256 * io_depth, hidden_size])
    table = tf.identity(table)
    table *= float(hidden_size)**0.5

    # The table stacks the 256-row blocks of all channels, so the ids of
    # channel `channel` are shifted by `channel * 256` to land in its block.
    per_channel_embeddings = [
        common_layers.gather(
            table, tf.squeeze(channel_targets, axis=3) + channel * 256)
        for channel, channel_targets in enumerate(per_channel_targets)
    ]
    return tf.concat(per_channel_embeddings, axis=-1)
def bottom_simple(self, x, name, reuse):
    """Look up embeddings for the ids in `x` under variable scope `name`.

    Args:
      x: tensor of ids. A rank-4 input has axis 3 squeezed away; an input of
        rank below 3 gets trailing axes appended until it is rank 3.
      name: name of the variable scope that wraps all ops created here.
      reuse: forwarded to `tf.variable_scope` as its `reuse` argument.

    Returns:
      The rows gathered from `self._get_weights()` for the ids in `x`,
      multiplied by `self._body_input_depth ** 0.5` when the hparam
      `multiply_embedding_mode` equals "sqrt_depth", and zeroed wherever the
      id (after symbol dropout) is 0.
    """
    with tf.variable_scope(name, reuse=reuse):
        # Normalise the ids to rank 3 before the lookup.
        if len(x.get_shape()) == 4:
            x = tf.squeeze(x, axis=3)
        while len(x.get_shape()) < 3:
            x = tf.expand_dims(x, axis=-1)

        weights = self._get_weights()
        hparams = self._model_hparams

        # NOTE(review): the second argument is presumably a keep probability
        # -- confirm against common_layers.dropout_no_scaling.
        x = common_layers.dropout_no_scaling(x, 1.0 - hparams.symbol_dropout)

        embedded = common_layers.gather(weights, x)
        if hparams.multiply_embedding_mode == "sqrt_depth":
            embedded *= self._body_input_depth**0.5

        # 1.0 where the id is non-zero and 0.0 where it is 0; the extra
        # trailing axis lets the mask broadcast over the embedding axis.
        nonzero_mask = tf.to_float(tf.not_equal(x, 0))
        embedded *= tf.expand_dims(nonzero_mask, -1)
        return embedded