def _build_shared_graph(self):
    """Assemble the encoder/decoder graph and its loss under ``self._scope``.

    An encoder ``MultiLayerFC`` (scope ``'encoder'``) is built on
    ``self._in_tensor`` with ``dims=self._dims[1:]``. A decoder
    ``MultiLayerFC`` (scope ``'decoder'``) is built on the encoder's first
    output with ``dims=self._dims[::-1][1:]``, i.e. the reversed dims list
    without its leading entry ``self._dims[-1]``.

    Side effects:
        * The encoder outputs are appended to ``self._outputs``.
        * ``self._loss`` is *assigned* (not accumulated) as the encoder loss
          plus the decoder loss, plus ``self._l2_reconst`` times
          ``tf.nn.l2_loss`` of (decoder output - ``self._in_tensor``).
    """
    with tf.variable_scope(self._scope, reuse=self._reuse):
        encoder = MultiLayerFC(
            in_tensor=self._in_tensor,
            dims=self._dims[1:],
            scope='encoder',
            l2_reg=self._l2_reg,
            dropout_in=self._dropout,
            dropout_mid=self._dropout,
            reuse=self._reuse,
        )
        latent = encoder.get_outputs()[0]

        # Mirror the encoder widths for the decoder; relu/dropout flags are
        # requested at the input, hidden and output stages.
        mirrored_dims = self._dims[::-1][1:]
        decoder = MultiLayerFC(
            in_tensor=latent,
            dims=mirrored_dims,
            scope='decoder',
            l2_reg=self._l2_reg,
            relu_in=True,
            dropout_in=self._dropout,
            relu_mid=True,
            dropout_mid=self._dropout,
            relu_out=True,
            dropout_out=self._dropout,
            reuse=self._reuse,
        )

        self._outputs += encoder.get_outputs()

        # Sub-network losses first, then the weighted reconstruction term.
        self._loss = encoder.get_loss() + decoder.get_loss()
        reconstruction = decoder.get_outputs()[0]
        reconstruction_error = reconstruction - self._in_tensor
        self._loss += self._l2_reconst * tf.nn.l2_loss(reconstruction_error)
def _build_shared_graph(self):
    """Build the cached-embedding graph, its MLP transform and refresh ops.

    Under ``self._scope`` two non-trainable variables are created:
    ``embedding`` (float32, shape ``self._shape``, initialised with
    ``self._initializer``) and ``flag`` (bool, one entry per embedding row,
    initialised to False).

    Forward path: the deduplicated ``self._ids`` are set to True in
    ``flag``. A control dependency ensures the lookup of ``self._ids`` and
    the ``MultiLayerFC`` built on it (scope ``self._scope + '/MLP'``) run
    after that scatter. The MLP's outputs and loss are accumulated into
    ``self._outputs`` and ``self._loss``.

    Refresh path, exposed as two ops:
        * ``self._update_node`` builds a second ``MultiLayerFC`` on the same
          scope string with ``reuse=True`` and ``train=False``, applies it to
          every row whose flag is True, and scatters the result back into
          ``embedding`` at those rows.
        * ``self._clear_flag`` scatters False into ``flag`` at those rows.
    """
    with tf.variable_scope(self._scope, reuse=self._reuse):
        self._embedding = tf.get_variable(
            'embedding',
            shape=self._shape,
            dtype=tf.float32,
            initializer=self._initializer,
            trainable=False,
        )
        all_false = tf.constant_initializer(value=False, dtype=tf.bool)
        self._flag = tf.get_variable(
            'flag',
            shape=[self._shape[0]],
            dtype=tf.bool,
            initializer=all_false,
            trainable=False,
        )
        mlp_scope = self._scope + '/MLP'

        # Flag every id in the current batch; tf.unique removes duplicates so
        # each row index appears once in the scatter.
        seen_ids, _ = tf.unique(self._ids)
        mark_seen = tf.scatter_update(
            self._flag, seen_ids, tf.ones_like(seen_ids, dtype=tf.bool))
        with tf.control_dependencies([mark_seen]):
            # Ops created in this context only execute after mark_seen.
            batch_rows = tf.nn.embedding_lookup(self._embedding, self._ids)
            transform = MultiLayerFC(
                in_tensor=batch_rows,
                dims=self._mlp_dims,
                scope=mlp_scope,
                batch_norm=True,
                relu_out=True,
                l2_reg=self._l2_reg,
                train=self._train,
                reuse=self._reuse,
            )
        self._outputs += transform.get_outputs()
        self._loss += transform.get_loss()

        # Refresh path: gather every flagged row index (tf.where yields
        # [N, 1] coordinates, flattened to [N]).
        flagged_ids = tf.reshape(tf.where(self._flag), [-1])
        flagged_rows = tf.nn.embedding_lookup(self._embedding, flagged_ids)
        # Same scope string with reuse=True -- presumably shares the weights
        # of `transform` above; train=False for the refresh pass.
        refresher = MultiLayerFC(
            in_tensor=flagged_rows,
            dims=self._mlp_dims,
            scope=mlp_scope,
            batch_norm=True,
            relu_out=True,
            l2_reg=self._l2_reg,
            train=False,
            reuse=True,
        )
        self._update_node = tf.scatter_update(
            self._embedding, flagged_ids, refresher.get_outputs()[0])
        self._clear_flag = tf.scatter_update(
            self._flag, flagged_ids,
            tf.zeros_like(flagged_ids, dtype=tf.bool))