def ar2word2sent2doc(document,
                     *,
                     words,
                     char_embeddings,
                     word_embedding_size,
                     context_vector_size,
                     save_memory=True,
                     **rd2sent2doc_hyperparams):
    """char2word2sent2doc model without character embeddings as parameters"""
    assert ex.static_rank(document) == 3
    assert ex.static_rank(words) == 2
    assert ex.static_rank(char_embeddings) == 2

    with tf.variable_scope("char2word"):
        # When saving memory, only the words actually present in the batch
        # are embedded; otherwise the whole `words` table is.
        if save_memory:
            word_char_ids = tf.gather(words, ex.flatten(document))
        else:
            word_char_ids = words

        word_embeddings = ex.bidirectional_id_vector_to_embedding(
            word_char_ids,
            char_embeddings,
            output_size=word_embedding_size,
            context_vector_size=context_vector_size,
            dynamic_length=True)

    return rd2sent2doc(document,
                       word_embeddings,
                       context_vector_size=context_vector_size,
                       save_memory=save_memory,
                       **rd2sent2doc_hyperparams)
def font2char2word2sent2doc(document,
                            *,
                            words,
                            fonts,
                            char_embedding_size,
                            dropout_keep_prob,
                            mode,
                            **ar2word2sent2doc_hyperparams):
    """Document model whose character embeddings come from font images."""
    assert ex.static_rank(document) == 3
    assert ex.static_rank(words) == 2
    assert ex.static_rank(fonts) == 3

    # NOTE(review): dropout_keep_prob and mode are accepted but never used in
    # this body — presumably part of a shared hyperparameter contract; verify.
    char_embeddings = font2char(fonts,
                                char_embedding_size=char_embedding_size)

    return ar2word2sent2doc(document,
                            words=words,
                            char_embeddings=char_embeddings,
                            **ar2word2sent2doc_hyperparams)
def sequence_labeling_loss(logits, labels, sequence_length=None):
    """Mean cross-entropy loss for per-time-step sequence labeling.

    Args:
        logits: float tensor of shape (batch, time, num_classes).
        labels: int tensor of shape (batch, time).
        sequence_length: optional per-example valid lengths; when given,
            padded time steps are excluded from the mean via a mask.

    Returns:
        Scalar loss averaged over the (unmasked) time steps.
    """
    assert ex.static_rank(logits) == 3
    assert ex.static_rank(labels) == 2

    losses = tf.reshape(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=tf.reshape(logits, [-1, ex.static_shape(logits)[-1]]),
            labels=tf.reshape(labels, [-1])),
        [-1, *ex.static_shape(labels)[1:]])

    # BUG FIX: was `sequence_length == None`. With a tensor argument `==`
    # builds an elementwise comparison op (truthy), so the masked branch
    # could never be reached reliably; identity check is required.
    if sequence_length is None:
        return tf.reduce_mean(losses)

    mask = tf.sequence_mask(sequence_length, dtype=losses.dtype)
    # Mean over valid steps only: masked sum divided by the number of
    # unmasked positions.
    return tf.reduce_sum(losses * mask) / tf.reduce_sum(mask)
def test_font2char():
    """font2char should map a batch of font images to rank-2 embeddings."""
    embeddings = font2char.font2char(tf.zeros([64, 224, 224]),
                                     nums_of_channels=[32] * 4,
                                     nums_of_attention_channels=[32] * 3)
    assert ex.static_rank(embeddings) == 2
def font2char2word2sent2doc(document,
                            *,
                            words,
                            fonts,
                            dropout_keep_prob,
                            nums_of_cnn_channels,
                            nums_of_attention_cnn_channels,
                            mode,
                            **ar2word2sent2doc_hyperparams):
    """Document model using an attention CNN over font images for characters."""
    assert ex.static_rank(document) == 3
    assert ex.static_rank(words) == 2
    assert ex.static_rank(fonts) == 3

    # NOTE(review): dropout_keep_prob and mode are accepted but never used in
    # this body — presumably part of a shared hyperparameter contract; verify.
    char_embeddings = font2char(
        fonts,
        nums_of_channels=nums_of_cnn_channels,
        nums_of_attention_channels=nums_of_attention_cnn_channels)

    return ar2word2sent2doc(document,
                            words=words,
                            char_embeddings=char_embeddings,
                            **ar2word2sent2doc_hyperparams)
def batch_linear(h, output_size):
    """Apply one shared linear layer (weights + bias) to every batch element.

    `h` is rank 3; the weight matrix maps the last dimension to
    `output_size` and is tiled across the batch for tf.batch_matmul.
    """
    assert ex.static_rank(h) == 3
    shape = ex.static_shape(h)

    weights = ex.variable([shape[2], output_size])
    tiled_weights = tf.tile(tf.expand_dims(weights, 0), [shape[0], 1, 1])
    bias = ex.variable([output_size])

    return tf.batch_matmul(h, tiled_weights) + bias
def _cnn(h, nums_of_channels):
    """Stack of 3x3 conv + 2x2/stride-2 max-pool layers, one per entry."""
    assert ex.static_rank(h) == 4
    slim = tf.contrib.slim
    for layer_index, channels in enumerate(nums_of_channels):
        h = slim.conv2d(h, channels, 3, scope='conv{}'.format(layer_index))
        h = slim.max_pool2d(h, 2, 2, scope='pool{}'.format(layer_index))
    return h
def _attend_to_image(images, nums_of_channels):
    """Weight each pixel of `images` by a learned spatial attention map."""
    assert ex.static_rank(images) == 4
    attentions = _calculate_attention(images, nums_of_channels)
    collections.add_attention(attentions)
    # Transposing both operands aligns the rank-3 attention map with the
    # trailing axes of the rank-4 images so multiplication broadcasts over
    # the channel dimension; transposing back restores the original layout.
    weighted = tf.transpose(images) * tf.transpose(attentions)
    return tf.transpose(weighted)
def char2word2sent2doc(document,
                       *,
                       words,
                       char_space_size,
                       char_embedding_size,
                       **ar2word2sent2doc_hyperparams):
    """
    Hierarchical document model with learned character embeddings.

    The argument `document` is in the shape of
    (#examples, #sentences per document, #words per sentence).
    """
    assert ex.static_rank(document) == 3
    assert ex.static_rank(words) == 2

    with tf.variable_scope("char_embeddings"):
        char_embeddings = ex.embeddings(id_space_size=char_space_size,
                                        embedding_size=char_embedding_size,
                                        name="char_embeddings")

    return ar2word2sent2doc(document,
                            words=words,
                            char_embeddings=char_embeddings,
                            **ar2word2sent2doc_hyperparams)
def font2char(font, *, nums_of_channels, nums_of_attention_channels):
    """Embed rank-3 font images into character vectors via an attention CNN."""
    assert ex.static_rank(font) == 3
    images = tf.expand_dims(font, -1)
    attended = _attend_to_image(images, nums_of_attention_channels)
    embeddings = tf.contrib.slim.flatten(_cnn(attended, nums_of_channels))
    # Visualize (up to) the first 256 embeddings as a single image summary.
    ex.summary.image(
        tf.expand_dims(tf.expand_dims(embeddings[:256], 0), 3))
    return embeddings
def rd2sent2doc(document,
                word_embeddings,
                *,
                sentence_embedding_size,
                document_embedding_size,
                context_vector_size,
                save_memory=False):
    """word2sent2doc model lacking word embeddings as parameters"""
    assert ex.static_rank(document) == 3
    assert ex.static_rank(word_embeddings) == 2

    embeddings_to_embedding = functools.partial(
        ex.bidirectional_embeddings_to_embedding,
        context_vector_size=context_vector_size)

    with tf.variable_scope("word2sent"):
        # word_embeddings.shape == (#batch * #sent * #word, emb_size)
        # if save_memory else
        # (vocab_size, emb_size)
        sentences = _flatten_document_into_sentences(document)

        if save_memory:
            sentence_words = _restore_sentence_shape(word_embeddings,
                                                     sentences)
        else:
            sentence_words = tf.gather(word_embeddings, sentences)

        sentence_embeddings = _restore_document_shape(
            embeddings_to_embedding(
                sentence_words,
                sequence_length=ex.id_vector_to_length(sentences),
                output_size=sentence_embedding_size),
            document)

    with tf.variable_scope("sent2doc"):
        return embeddings_to_embedding(
            sentence_embeddings,
            sequence_length=ex.id_tensor_to_length(document),
            output_size=document_embedding_size)
def word2sent2doc(document,
                  *,
                  word_space_size,
                  word_embedding_size,
                  **rd2sent2doc_hyperparams):
    """Hierarchical document model with learned word embeddings."""
    assert ex.static_rank(document) == 3

    with tf.variable_scope("word_embeddings"):
        # Only the words present in the batch are gathered, matching the
        # save_memory=True contract of rd2sent2doc below.
        embedding_table = ex.embeddings(id_space_size=word_space_size,
                                        embedding_size=word_embedding_size,
                                        name="word_embeddings")
        word_embeddings = tf.gather(embedding_table, ex.flatten(document))

    return rd2sent2doc(document,
                       word_embeddings,
                       save_memory=True,
                       **rd2sent2doc_hyperparams)
def _calculate_attention(images, nums_of_channels):
    """Compute a per-pixel attention map normalized over all pixels.

    The CNN output is summed over channels, resized back to the input's
    spatial size, and softmax-normalized across every spatial position.
    """
    assert ex.static_rank(images) == 4

    channel_sums = tf.reduce_sum(_cnn(images, nums_of_channels), axis=3)
    logits = tf.squeeze(
        tf.image.resize_nearest_neighbor(
            tf.expand_dims(channel_sums, axis=-1),
            tf.shape(images)[1:3]),
        axis=3)

    # Softmax over the flattened (height * width) positions per example,
    # then restore the spatial layout.
    flat_logits = tf.reshape(logits, [tf.shape(logits)[0], -1])
    return tf.reshape(tf.nn.softmax(flat_logits), tf.shape(logits))
def font2char(fonts, char_embedding_size):
    """Embed rank-3 font images into character vectors with a LeNet CNN."""
    assert ex.static_rank(fonts) == 3
    images = tf.expand_dims(fonts, -1)
    embeddings = ex.lenet(images, output_size=char_embedding_size)
    # Visualize (up to) the first 256 embeddings as a single image summary.
    summary_image = tf.expand_dims(tf.expand_dims(embeddings[:256], 0), 3)
    ex.summary.image(summary_image)
    return embeddings
def test_attend_to_image():
    """_attend_to_image should preserve the rank-4 shape of its input."""
    attended = font2char._attend_to_image(tf.zeros([64, 224, 224, 1]),
                                          nums_of_channels=[32] * 3)
    assert ex.static_rank(attended) == 4