コード例 #1
0
def test_expand_dims_for_broadcast():
    """Exercise expand_dims_for_broadcast on a matching and a mismatched input.

    A length-2 vector expanded against a (2, 3, 3) tensor must evaluate to
    shape (2, 1, 1). A length-3 vector against the same tensor must raise
    InvalidArgumentError when the resulting tensor is evaluated.
    """
    # Both outer entries of the target tensor hold the same 3x3 block.
    block = [
        [1, 2, 3],
        [4, 5, 6],
        [4, 5, 6],
    ]

    with clean_session():
        target = tf.constant([block, block], dtype=tf.float32)
        good = tf.constant([1, 2], dtype=tf.float32)

        assert target.get_shape().as_list() == [2, 3, 3]
        assert good.get_shape().as_list() == [2]

        expanded = expand_dims_for_broadcast(good, target)
        assert expanded.eval().shape == (2, 1, 1)

        # Leading dimension 3 does not match the target's leading dimension 2.
        mismatched = tf.constant([1, 2, 3], dtype=tf.float32)
        bad_expanded = expand_dims_for_broadcast(mismatched, target)

        # The error is expected at evaluation time, not at graph construction.
        with pytest.raises(InvalidArgumentError):
            bad_expanded.eval()
コード例 #2
0
ファイル: seq_batch.py プロジェクト: lvyiwei1/StylePTB
def embed(sequence_batch, embeds):
    """Gather embeddings for a sequence batch and zero out the pad positions.

    Args:
        sequence_batch: object exposing `.values` (used as the indices for
            tf.gather) and `.mask`.
        embeds: the params tensor that tf.gather indexes into.

    Returns:
        SequenceBatch: the gathered embeddings multiplied by the mask
            (expanded for broadcasting), paired with the original mask.
    """
    pad_mask = sequence_batch.mask

    # Look up the embeddings and assert that none of them is NaN/Inf.
    looked_up = tf.verify_tensor_all_finite(
        tf.gather(embeds, sequence_batch.values), 'embedded_values')

    # Multiplying by the expanded mask sets every pad embedding to zero.
    return SequenceBatch(
        looked_up * expand_dims_for_broadcast(pad_mask, looked_up), pad_mask)
コード例 #3
0
ファイル: seq_batch.py プロジェクト: lvyiwei1/StylePTB
def weighted_sum(seq_batch, weights):
    """Sum each sequence of a SequenceBatch along axis 1 using per-step weights.

    Args:
        seq_batch (SequenceBatch): supplies `.values` and `.mask`.
        weights (Tensor): of shape (batch_size, seq_length). Entries that fall
            outside seq_batch's mask are multiplied by the mask first, so they
            do not contribute.

    Returns:
        Tensor: of shape (batch_size, :, ..., :), i.e. values with the
            seq_length axis summed out.
    """
    # Drop any weight that sits on a masked-out position.
    masked_weights = weights * seq_batch.mask

    elems = seq_batch.values
    # (batch_size, seq_length, X)
    scaled = elems * expand_dims_for_broadcast(masked_weights, elems)

    # Collapse the sequence axis -> (batch_size, X)
    return tf.reduce_sum(scaled, 1)
コード例 #4
0
def test_broadcast():
    """Check that a (3, 2) mask is expanded and then broadcast to (3, 2, 2).

    After expand_dims_for_broadcast the static shape must be [3, 2, 1]; after
    broadcast it must be [3, 2, 2], and every mask entry must be repeated
    along the last axis.
    """
    with clean_session():
        data = tf.constant(
            [
                [[1, 2], [1, 2]],
                [[1, 2], [3, 4]],
                [[5, 6], [7, 8]],
            ],
            dtype=tf.float32)

        flat_mask = tf.constant([[1, 0], [1, 1], [0, 1]], dtype=tf.float32)

        # Each mask entry repeated across the trailing axis of `data`.
        expected = np.array(
            [
                [[1, 1], [0, 0]],
                [[1, 1], [1, 1]],
                [[0, 0], [1, 1]],
            ],
            dtype=np.float32)

        assert data.get_shape().as_list() == [3, 2, 2]
        assert flat_mask.get_shape().as_list() == [3, 2]

        expanded = expand_dims_for_broadcast(flat_mask, data)
        assert expanded.get_shape().as_list() == [3, 2, 1]

        tiled = broadcast(expanded, data)
        assert tiled.get_shape().as_list() == [3, 2, 2]

        assert_array_equal(tiled.eval(), expected)
コード例 #5
0
ファイル: seq_batch.py プロジェクト: lvyiwei1/StylePTB
def change_pad_value(values, mask, pad_val):
    """Given a set of values and a pad mask, change the value of all pad entries.

    Args:
        values (Tensor): of shape [batch_size, seq_length, :, ..., :].
        mask (Tensor): binary float tensor of shape [batch_size, seq_length]
        pad_val (float): value to set all pad entries to

    Returns:
        Tensor: a new Tensor of same shape as values
    """
    # Broadcast the mask to match the shape of values.
    mask = expand_dims_for_broadcast(
        mask, values)  # (batch_size, seq_length, 1, ..., 1)
    mask = broadcast(mask, values)
    mask = tf.cast(mask, tf.bool)  # element-wise selection needs a bool condition

    # Tensor shaped like values with pad_val in every position.
    broadcast_val = pad_val * tf.ones(tf.shape(values))

    # tf.select was renamed to tf.where in TensorFlow 1.0 and no longer exists
    # there. Prefer tf.select when the installed version still provides it, so
    # behavior on old installs is unchanged; otherwise use tf.where.
    select = getattr(tf, 'select', None) or tf.where

    # Keep values where the mask is True, use pad_val elsewhere.
    new_values = select(mask, values, broadcast_val)
    return new_values