def test_broadcast():
    """Check that a [3, 2] mask is expanded and broadcast to a [3, 2, 2] tensor.

    The mask is first passed through expand_dims_for_broadcast (expected
    static shape [3, 2, 1]) and then through broadcast (expected static shape
    [3, 2, 2]). The evaluated result must repeat each mask entry along the
    last axis.
    """
    with clean_session():
        values = tf.constant(
            [
                [[1, 2], [1, 2]],
                [[1, 2], [3, 4]],
                [[5, 6], [7, 8]],
            ],
            dtype=tf.float32,
        )
        raw_mask = tf.constant(
            [
                [1, 0],
                [1, 1],
                [0, 1],
            ],
            dtype=tf.float32,
        )
        # Each mask entry duplicated across the trailing axis of `values`.
        expected = np.array(
            [
                [[1, 1], [0, 0]],
                [[1, 1], [1, 1]],
                [[0, 0], [1, 1]],
            ],
            dtype=np.float32,
        )

        # Sanity-check the static shapes of the inputs.
        assert values.get_shape().as_list() == [3, 2, 2]
        assert raw_mask.get_shape().as_list() == [3, 2]

        # Step 1: append singleton dimensions to the mask.
        expanded_mask = expand_dims_for_broadcast(raw_mask, values)
        assert expanded_mask.get_shape().as_list() == [3, 2, 1]

        # Step 2: broadcast the expanded mask to the shape of `values`.
        full_mask = broadcast(expanded_mask, values)
        assert full_mask.get_shape().as_list() == [3, 2, 2]

        assert_array_equal(full_mask.eval(), expected)
def change_pad_value(values, mask, pad_val):
    """Given a set of values and a pad mask, change the value of all pad entries.

    Entries whose mask value is non-zero are kept unchanged; entries whose
    mask value is zero are replaced by pad_val.

    Args:
        values (Tensor): of shape [batch_size, seq_length, :, ..., :].
        mask (Tensor): binary float tensor of shape [batch_size, seq_length]
        pad_val (float): value to set all pad entries to. It is cast to the
            dtype of `values` before being written.

    Returns:
        Tensor: a new Tensor of same shape as values
    """
    # broadcast the mask to match shape of values
    mask = expand_dims_for_broadcast(mask, values)  # (batch_size, seq_length, 1, ..., 1)
    mask = broadcast(mask, values)
    keep = tf.cast(mask, tf.bool)  # True where the original value is kept

    # Build the fill tensor in the dtype of `values`. tf.select requires both
    # branches to share a dtype, so a float32-only tf.ones(...) would fail for
    # values of any other dtype.
    ones = tf.ones_like(values)
    fill = tf.cast(pad_val, ones.dtype) * ones

    return tf.select(keep, values, fill)