コード例 #1
0
def test_input_shape_and_dtype(
    rnn_model: RNNModel,
    batch_prev_tkids: torch.Tensor,
):
    r"""Input must be long tensor.

    Check that ``rnn_model.pred`` accepts the ``batch_prev_tkids`` fixture
    without raising. The fixture is presumably a long (int64) tensor of
    token ids -- TODO confirm against the fixture definition; this test
    does not inspect its dtype itself.

    Exceptions are deliberately NOT caught: pytest already reports any
    exception raised inside a test as a failure, together with the original
    traceback. Wrapping the call in ``try``/``except Exception: assert False``
    would hide that traceback, and under ``python -O`` (asserts stripped)
    would turn a failure into a silent pass.
    """
    rnn_model = rnn_model.eval()
    rnn_model.pred(batch_prev_tkids)
コード例 #2
0
def test_value_range(
    rnn_model: RNNModel,
    batch_prev_tkids: torch.Tensor,
):
    r"""Return values are probabilities.

    Every entry produced by ``pred`` must lie in the closed interval
    ``[0, 1]``, and the entries along the last dimension must sum to one
    (up to ``torch.allclose`` default tolerances).
    """
    probs = rnn_model.eval().pred(batch_prev_tkids)

    # Check each bound on its own so a failure names the violated bound.
    assert (probs >= 0).all().item()
    assert (probs <= 1).all().item()

    # Entries along the last dimension must be normalised to 1.
    totals = probs.sum(dim=-1)
    assert torch.allclose(totals, torch.ones_like(totals))
コード例 #3
0
def test_return_shape_and_dtype(
    rnn_model: RNNModel,
    batch_prev_tkids: torch.Tensor,
    batch_next_tkids: torch.Tensor,
):
    r"""Return float tensor with 0 dimension.

    ``loss_fn`` must return a scalar (0-dimensional) tensor whose dtype is
    ``torch.float``.
    """
    model = rnn_model.eval()
    loss = model.loss_fn(
        batch_prev_tkids=batch_prev_tkids,
        batch_next_tkids=batch_next_tkids,
    )

    # A 0-dimensional tensor has rank zero (empty shape).
    assert loss.ndim == 0
    # Loss must be a float32 tensor.
    assert loss.dtype == torch.float
コード例 #4
0
def test_input_shape_and_dtype(
    rnn_model: RNNModel,
    batch_prev_tkids: torch.Tensor,
    batch_next_tkids: torch.Tensor,
):
    r"""Input tensors must be long tensors and have the same shape.

    Same shape is required since we are using teacher forcing.

    Check that ``rnn_model.loss_fn`` accepts the two fixtures without
    raising. This test does not itself inspect the fixtures' dtype or
    shape; it presumes the fixtures satisfy the contract above -- TODO
    confirm against the fixture definitions.

    Exceptions are deliberately NOT caught: pytest already reports any
    exception raised inside a test as a failure, together with the original
    traceback. Wrapping the call in ``try``/``except Exception: assert False``
    would hide that traceback, and under ``python -O`` (asserts stripped)
    would turn a failure into a silent pass.
    """
    rnn_model = rnn_model.eval()
    rnn_model.loss_fn(
        batch_prev_tkids=batch_prev_tkids,
        batch_next_tkids=batch_next_tkids,
    )
コード例 #5
0
def test_return_shape_and_dtype(
    rnn_model: RNNModel,
    batch_prev_tkids: torch.Tensor,
):
    r"""Return float tensor with correct shape.

    For an input of shape ``(B, S)``, ``pred`` must return a
    ``torch.float`` tensor of shape ``(B, S, V)``, where ``V`` is
    ``rnn_model.emb.num_embeddings``.
    """
    model = rnn_model.eval()
    out = model.pred(batch_prev_tkids)

    # Output dtype first, then its shape.
    assert out.dtype == torch.float

    b = batch_prev_tkids.shape[0]
    s = batch_prev_tkids.shape[1]
    v = model.emb.num_embeddings
    expected_shape = (b, s, v)
    assert out.shape == expected_shape