def test_expand_dims_for_broadcast():
    """expand_dims_for_broadcast should add singleton axes so weights broadcast.

    A length-2 weight vector paired with a (2, 3, 3) tensor is expected to come
    back with shape (2, 1, 1). A length-3 vector, whose leading dimension
    disagrees with the tensor's, is expected to raise InvalidArgumentError
    when evaluated.
    """
    with clean_session():
        rows = [[1, 2, 3], [4, 5, 6], [4, 5, 6]]
        arr = tf.constant([rows, rows], dtype=tf.float32)
        weights = tf.constant([1, 2], dtype=tf.float32)

        # Sanity-check the static shapes of the inputs.
        assert arr.get_shape().as_list() == [2, 3, 3]
        assert weights.get_shape().as_list() == [2]

        expanded = expand_dims_for_broadcast(weights, arr)
        assert expanded.eval().shape == (2, 1, 1)

        # Leading dimension 3 does not match arr's leading dimension 2.
        mismatched = tf.constant([1, 2, 3], dtype=tf.float32)
        mismatched_expanded = expand_dims_for_broadcast(mismatched, arr)
        with pytest.raises(InvalidArgumentError):
            mismatched_expanded.eval()
def embed(sequence_batch, embeds):
    """Look up embeddings for a SequenceBatch of ids and zero out pad positions.

    Args:
        sequence_batch (SequenceBatch): batch whose ``values`` are indices
            gathered from ``embeds`` and whose ``mask`` marks non-pad entries.
        embeds (Tensor): embedding table, gathered along its first axis.

    Returns:
        SequenceBatch: the gathered embeddings multiplied by the broadcast
        mask, paired with the original mask.
    """
    mask, ids = sequence_batch.mask, sequence_batch.values
    vectors = tf.verify_tensor_all_finite(
        tf.gather(embeds, ids), 'embedded_values')
    # Multiplying by the (broadcast) mask sets every pad embedding to zero.
    vectors = vectors * expand_dims_for_broadcast(mask, vectors)
    return SequenceBatch(vectors, mask)
def weighted_sum(seq_batch, weights):
    """Collapse the sequence axis of a SequenceBatch using per-position weights.

    Args:
        seq_batch (SequenceBatch): a SequenceBatch.
        weights (Tensor): a Tensor of shape (batch_size, seq_length).
            Determines the weights. Weights outside the seq_batch's mask
            are ignored.

    Returns:
        Tensor: of shape (batch_size, :, ..., :)
    """
    values = seq_batch.values
    # Zero the weights at masked-out positions, then add trailing singleton
    # axes so they broadcast against values.
    masked_weights = expand_dims_for_broadcast(weights * seq_batch.mask, values)
    # (batch_size, seq_length, X) -> (batch_size, X)
    return tf.reduce_sum(values * masked_weights, 1)
def test_broadcast():
    """expand_dims_for_broadcast then broadcast should tile a mask over values."""
    with clean_session():
        values = tf.constant([
            [[1, 2], [1, 2]],
            [[1, 2], [3, 4]],
            [[5, 6], [7, 8]],
        ], dtype=tf.float32)
        mask = tf.constant([
            [1, 0],
            [1, 1],
            [0, 1],
        ], dtype=tf.float32)
        # Each mask entry is repeated across the last axis of values.
        expected = np.array([
            [[1, 1], [0, 0]],
            [[1, 1], [1, 1]],
            [[0, 0], [1, 1]],
        ], dtype=np.float32)

        assert values.get_shape().as_list() == [3, 2, 2]
        assert mask.get_shape().as_list() == [3, 2]

        expanded = expand_dims_for_broadcast(mask, values)
        assert expanded.get_shape().as_list() == [3, 2, 1]

        tiled = broadcast(expanded, values)
        assert tiled.get_shape().as_list() == [3, 2, 2]

        assert_array_equal(tiled.eval(), expected)
def change_pad_value(values, mask, pad_val):
    """Given a set of values and a pad mask, change the value of all pad entries.

    Args:
        values (Tensor): of shape [batch_size, seq_length, :, ..., :].
        mask (Tensor): binary float tensor of shape [batch_size, seq_length]
        pad_val (float): value to set all pad entries to

    Returns:
        Tensor: a new Tensor with the same shape and dtype as values. It equals
        values where mask is nonzero and pad_val where mask is zero.
    """
    # broadcast the mask to match shape of values
    mask = expand_dims_for_broadcast(
        mask, values)  # (batch_size, seq_length, 1, ..., 1)
    mask = broadcast(mask, values)
    mask = tf.cast(mask, tf.bool)  # nonzero -> True (keep), zero -> False (pad)

    # Build the fill tensor in values' own dtype. tf.ones(shape) defaults to
    # float32, which made the select below fail for e.g. float64 values.
    broadcast_val = pad_val * tf.ones_like(values)

    # tf.select was removed in TensorFlow 1.0 in favour of the 3-argument
    # tf.where; use whichever one this TensorFlow version provides.
    select = getattr(tf, 'select', None) or tf.where
    new_values = select(mask, values, broadcast_val)
    return new_values