def test_seq_length(self):
    """Check truncation, right alignment and static shape with ``seq_length=4``.

    The expected arrays show that the 5-token sequence keeps only its last
    4 tokens, while shorter sequences are padded on the left with index 0
    and mask 0.
    """
    unk = '<unk>'
    vocab = VocabExample('<unk> a b c'.split(), unk)
    # Index mapping implied by the expected arrays: <unk>=0, a=1, b=2, c=3.
    sequences = [
        'a b a b c'.split(),  # 5 tokens: longer than seq_length=4
        'a b'.split(),
        ['b'],
        ['c'],
    ]

    expected_values = np.array([
        [2, 1, 2, 3],
        [0, 0, 1, 2],
        [0, 0, 0, 2],
        [0, 0, 0, 3],
    ], dtype=np.int32)
    expected_mask = np.array([
        [1, 1, 1, 1],
        [0, 0, 1, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
    ], dtype=np.float32)

    with clean_session():
        model = FeedSequenceBatch(align='right', seq_length=4)
        feed = model.inputs_to_feed_dict(sequences, vocab)
        expected_feed = {model.values: expected_values, model.mask: expected_mask}
        assert_array_collections_equal(expected_feed, feed)

        # With seq_length fixed, the second dimension must be statically known.
        values_tensor = tf.identity(model.values)
        mask_tensor = tf.identity(model.mask)
        for tensor in (values_tensor, mask_tensor):
            assert tensor.get_shape().as_list() == [None, 4]
def __init__(self, input_scores):
    """Score each candidate by summing the input scores it is aligned to.

    Each candidate is aligned with elements of the input. Its score is the
    sum of the aligned input scores, weighted by the fed alignment weights.

    Args:
        input_scores (Tensor): of shape (batch_size, input_length)
    """
    # (batch_size * input_length,)
    flat_input_scores = tf.reshape(input_scores, shape=[-1])
    self._input_length = input_scores.get_shape().as_list()[1]

    # Both feeds are (total_candidates, max_alignments).
    alignments = FeedSequenceBatch()
    alignment_weights = FeedSequenceBatch(dtype=tf.float32)

    # Presumably gathers the flat input score at each aligned index:
    # (total_candidates, max_alignments)
    aligned_scores = embed(alignments, flat_input_scores)
    # NOTE(review): with_pad_value(0) presumably zeroes padded weight slots.
    padded_weights = alignment_weights.with_pad_value(0).values
    # (total_candidates,)
    flat_scores = weighted_sum(aligned_scores, padded_weights)

    # Regroup the flat candidate scores per example: (batch_size, num_candidates)
    unflatten = FeedSequenceBatch()
    scores = embed(unflatten, flat_scores).with_pad_value(0)

    self._alignments_flat = alignments
    self._alignment_weights_flat = alignment_weights
    self._unflatten = unflatten
    self._scores = scores
def test_no_sequences(self):
    """An empty batch of sequences yields (0, 0) arrays for values and mask."""
    vocab = SimpleVocab('a b c'.split())
    with clean_session():
        model = FeedSequenceBatch()
        fetches = [tf.identity(model.values), tf.identity(model.mask)]
        values_out, mask_out = model.compute(fetches, [], vocab)
        assert values_out.shape == (0, 0)
        assert mask_out.shape == (0, 0)
def test_right_align(self, inputs):
    """With align='right', each row is padded on the left with index/mask 0."""
    expected_values = np.array([
        [1, 1, 2, 2, 3],
        [0, 0, 0, 1, 2],
        [0, 0, 0, 0, 2],
        [0, 0, 0, 0, 3],
    ], dtype=np.int32)
    expected_mask = np.array([
        [1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1],
    ], dtype=np.float32)

    with clean_session():
        model = FeedSequenceBatch(align='right')
        expected_feed = {model.values: expected_values, model.mask: expected_mask}
        # `inputs` unpacks into positional and keyword arguments for the feed call.
        args, kwargs = inputs
        actual_feed = model.inputs_to_feed_dict(*args, **kwargs)
        assert_array_collections_equal(expected_feed, actual_feed)
def __init__(self, token_embeds, align='left', seq_length=None, name='SequenceEmbedder'):
    """Build the graph that embeds a fed batch of token sequences.

    Args:
        token_embeds (Tensor): a Tensor of shape (token_vocab_size, token_dim)
        align (str): see FeedSequenceBatch
        seq_length (int): see FeedSequenceBatch
        name (str): name scope under which the ops are created
    """
    with tf.name_scope(name):
        # (sequence_vocab_size, seq_length)
        token_batch = FeedSequenceBatch(align=align, seq_length=seq_length)
        embedded_tokens = embed(token_batch, token_embeds)
        # Reduction to one vector per sequence is delegated to embed_sequences
        # (not visible here; presumably overridden by subclasses).
        sequence_embeds = self.embed_sequences(embedded_tokens)

    self._sequence_batch = token_batch
    self._embedded_sequence_batch = embedded_tokens
    self._embeds = sequence_embeds
def __init__(self, query, cand_embeds, project_query=False):
    """Score a fed batch of candidates against a query via attention.

    Args:
        query (Tensor): of shape (batch_size, query_dim)
        cand_embeds (Tensor): of shape (cand_vocab_size, cand_dim)
        project_query (bool): if True, project the query so that its dimension
            matches that of cand_embeds
    """
    with tf.name_scope("CandidateScorer"):
        candidates = FeedSequenceBatch()
        # (batch_size, num_candidates, cand_dim)
        embedded_candidates = embed(candidates, cand_embeds)
        attention = Attention(embedded_candidates, query, project_query=project_query)

    self._attention = attention
    self._cand_batch = candidates
    candidate_mask = candidates.mask
    self._scores = SequenceBatch(attention.logits, candidate_mask)
    self._probs = SequenceBatch(attention.probs, candidate_mask)
def model(self):
    """Return a left-aligned FeedSequenceBatch for the tests to use."""
    left_batch = FeedSequenceBatch(align='left')
    return left_batch