Пример #1
0
def test_broadcast():
    """Check that a [3, 2] mask can be expanded and broadcast against [3, 2, 2] values.

    The mask is first passed through expand_dims_for_broadcast (expected
    static shape [3, 2, 1]) and then through broadcast (expected static
    shape [3, 2, 2]). The evaluated result must equal the original mask with
    each entry repeated along the last axis.
    """
    def static_shape(tensor):
        # Statically known shape of a tensor, as a plain Python list.
        return tensor.get_shape().as_list()

    with clean_session():
        values = tf.constant(
            [
                [[1, 2], [1, 2]],
                [[1, 2], [3, 4]],
                [[5, 6], [7, 8]],
            ],
            dtype=tf.float32)

        mask = tf.constant([[1, 0], [1, 1], [0, 1]], dtype=tf.float32)

        # Each mask entry duplicated across the trailing dimension of values.
        expected = np.array(
            [
                [[1, 1], [0, 0]],
                [[1, 1], [1, 1]],
                [[0, 0], [1, 1]],
            ],
            dtype=np.float32)

        # Sanity-check the inputs before transforming the mask.
        assert static_shape(values) == [3, 2, 2]
        assert static_shape(mask) == [3, 2]

        expanded_mask = expand_dims_for_broadcast(mask, values)
        assert static_shape(expanded_mask) == [3, 2, 1]

        full_mask = broadcast(expanded_mask, values)
        assert static_shape(full_mask) == [3, 2, 2]

        assert_array_equal(full_mask.eval(), expected)
Пример #2
0
def change_pad_value(values, mask, pad_val):
    """Given a set of values and a pad mask, change the value of all pad entries.

    Entries whose mask value is nonzero keep their value from `values`;
    entries whose mask value is zero are replaced by `pad_val`.

    Args:
        values (Tensor): of shape [batch_size, seq_length, :, ..., :].
        mask (Tensor): binary float tensor of shape [batch_size, seq_length]
        pad_val (float): value to set all pad entries to

    Returns:
        Tensor: a new Tensor of same shape (and dtype) as values
    """
    # broadcast the mask to match shape of values
    mask = expand_dims_for_broadcast(
        mask, values)  # (batch_size, seq_length, 1, ..., 1)
    mask = broadcast(mask, values)
    keep = tf.cast(mask, tf.bool)  # True where the original value is kept

    # Build the fill tensor in the dtype of `values`: the select op requires
    # both branches to share a dtype, and tf.ones defaults to float32, which
    # would break for e.g. float64 or integer inputs.
    dtype = values.dtype.base_dtype
    broadcast_val = tf.cast(pad_val, dtype) * tf.ones(
        tf.shape(values), dtype=dtype)

    # tf.select was removed in TensorFlow 1.0 in favour of tf.where; prefer
    # the original op when it exists so behaviour on old releases is unchanged.
    select = getattr(tf, 'select', None) or tf.where

    return select(keep, values, broadcast_val)