from typing import Dict, Tuple, Union

import hypothesis.strategies as st

# FullyConnected and fully_connected_kwargs come from the project under
# test (import path not shown in this snippet).

@st.composite
def fully_connecteds(
    draw, return_kwargs: bool = False
) -> Union[
    st.SearchStrategy[FullyConnected],
    st.SearchStrategy[Tuple[FullyConnected, Dict]],
]:
    """Returns a SearchStrategy for FullyConnected."""
    kwargs = draw(fully_connected_kwargs())
    if not return_kwargs:
        return FullyConnected(**kwargs)
    return FullyConnected(**kwargs), kwargs
def test_fully_connected_raises_value_error_negative_num_hidden_layers(
    kwargs: Dict, num_hidden_layers: int
) -> None:
    """Ensures ValueError raised when num_hidden_layers < 0."""
    kwargs["num_hidden_layers"] = num_hidden_layers
    with pytest.raises(ValueError):
        FullyConnected(**kwargs)
def test_fully_connected_raises_value_error_hidden_activation_fn_not_none(
    kwargs: Dict, hidden_activation_fn: torch.nn.Module,
) -> None:
    """Ensures ValueError raised when no hidden layers and no act fn."""
    kwargs["num_hidden_layers"] = 0
    kwargs["hidden_activation_fn"] = hidden_activation_fn
    with pytest.raises(ValueError):
        FullyConnected(**kwargs)
def test_fully_connected_raises_value_error_hidden_size_not_none(
    kwargs: Dict, hidden_size: int
) -> None:
    """Ensures ValueError raised when no hidden layers and not hidden_size."""
    kwargs["num_hidden_layers"] = 0
    kwargs["hidden_size"] = hidden_size
    with pytest.raises(ValueError):
        FullyConnected(**kwargs)
def test_fully_connected_raises_value_error_dropout_negative_or_greater_than_1(
    kwargs: Dict, dropout: float
) -> None:
    """Ensures ValueError raised when dropout is < 0 or > 1."""
    kwargs["dropout"] = dropout
    assume(kwargs["num_hidden_layers"] != 0)
    with pytest.raises(ValueError):
        FullyConnected(**kwargs)
def test_fully_connected_raises_value_error_zero_hidden_dropout_not_none(
    kwargs: Dict, dropout: float
) -> None:
    """Ensures ValueError raised when num_hidden_layers=0 and dropout given."""
    kwargs["num_hidden_layers"] = 0
    kwargs["dropout"] = dropout
    with pytest.raises(ValueError):
        FullyConnected(**kwargs)
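# The four ValueError tests above pin down the constructor's validation
# rules. A minimal sketch of that validation, under the assumption that
# FullyConnected takes these kwargs (the helper name is hypothetical):
from typing import Optional

def _validate_fully_connected_args(
    num_hidden_layers: int,
    hidden_size: Optional[int],
    hidden_activation_fn: Optional[torch.nn.Module],
    dropout: Optional[float],
) -> None:
    if num_hidden_layers < 0:
        raise ValueError("num_hidden_layers must be >= 0")
    if num_hidden_layers == 0:
        # with no hidden layers these arguments would have no effect, so
        # passing any of them is treated as an error
        if hidden_size is not None:
            raise ValueError("hidden_size must be None")
        if hidden_activation_fn is not None:
            raise ValueError("hidden_activation_fn must be None")
        if dropout is not None:
            raise ValueError("dropout must be None")
    elif dropout is not None and not (0.0 <= dropout <= 1.0):
        raise ValueError("dropout must be in [0.0, 1.0]")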
def build(
    fully_connected_cfg: fully_connected_pb2.FullyConnected,
    input_features: int,
    output_features: int,
) -> FullyConnected:
    """Returns a :py:class:`.FullyConnected` based on the config.

    Args:
        fully_connected_cfg: A ``FullyConnected`` protobuf object containing
            the config for the desired :py:class:`torch.nn.Module`.

        input_features: The number of features for the input.

        output_features: The number of output features.

    Returns:
        A :py:class:`.FullyConnected` based on the config.

    Example:
        >>> from google.protobuf import text_format
        >>> cfg_text = '''
        ... num_hidden_layers: 2;
        ... hidden_size: 64;
        ... activation {
        ...   relu { }
        ... }
        ... '''
        >>> cfg = text_format.Merge(
        ...     cfg_text,
        ...     fully_connected_pb2.FullyConnected()
        ... )
        >>> build(cfg, input_features=32, output_features=16)
        FullyConnected(
          (fully_connected): Sequential(
            (0): Linear(in_features=32, out_features=64, bias=True)
            (1): ReLU()
            (2): Linear(in_features=64, out_features=64, bias=True)
            (3): ReLU()
            (4): Linear(in_features=64, out_features=16, bias=True)
          )
        )
    """
    activation = build_activation(fully_connected_cfg.activation)
    if isinstance(activation, torch.nn.Identity):
        # Identity signals "no activation" to FullyConnected
        activation = None

    # hidden_size == 0 in the protobuf means the field is unset
    hidden_size = None
    if fully_connected_cfg.hidden_size > 0:
        hidden_size = fully_connected_cfg.hidden_size

    return FullyConnected(
        in_features=input_features,
        out_features=output_features,
        num_hidden_layers=fully_connected_cfg.num_hidden_layers,
        hidden_size=hidden_size,
        hidden_activation_fn=activation,
    )
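# Note: this version of build() never forwards the config's optional
# dropout field, although fully_connected_module_match_cfg below checks
# for Dropout layers. A hedged sketch of that plumbing, assuming dropout
# is a FloatValue-style wrapper (per the HasField()/.value usage below)
# and that FullyConnected accepts a dropout kwarg (per the ValueError
# tests above); the helper name is hypothetical:
from typing import Optional

def _dropout_from_cfg(
    fully_connected_cfg: fully_connected_pb2.FullyConnected,
) -> Optional[float]:
    """Maps an unset or zero-valued dropout field to None."""
    if not fully_connected_cfg.HasField("dropout"):
        return None
    if fully_connected_cfg.dropout.value == 0:
        return None
    return fully_connected_cfg.dropout.value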
@st.composite
def fully_connecteds(
    draw, return_kwargs: bool = False
) -> Union[
    st.SearchStrategy[FullyConnected],
    st.SearchStrategy[Tuple[FullyConnected, Dict]],
]:
    """Returns a SearchStrategy for FullyConnected."""
    kwargs = {}
    kwargs["in_features"] = draw(st.integers(1, 32))
    kwargs["out_features"] = draw(st.integers(1, 32))
    kwargs["num_hidden_layers"] = draw(st.integers(0, 8))
    if kwargs["num_hidden_layers"] == 0:
        # hidden_size and hidden_activation_fn must be None when there
        # are no hidden layers (see the ValueError tests above)
        kwargs["hidden_size"] = None
        kwargs["hidden_activation_fn"] = None
    else:
        kwargs["hidden_size"] = draw(st.integers(1, 32))
        kwargs["hidden_activation_fn"] = draw(
            st.sampled_from([torch.nn.ReLU(), torch.nn.Tanh()])
        )
    if not return_kwargs:
        return FullyConnected(**kwargs)
    return FullyConnected(**kwargs), kwargs
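# A hedged usage sketch: with @st.composite applied, calling
# fully_connecteds() yields a strategy that hypothesis's @given can
# drive directly; the test body here is illustrative only.
from hypothesis import given

@given(module_and_kwargs=fully_connecteds(return_kwargs=True))
def test_fully_connected_strategy_builds_module(module_and_kwargs) -> None:
    module, kwargs = module_and_kwargs
    assert isinstance(module, FullyConnected)
    # the wrapped torch module is exposed as .fully_connected
    assert isinstance(module.fully_connected, torch.nn.Module)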
def test_fully_connected_forward_returns_correct_size(
    kwargs: Dict, tensor: torch.Tensor
) -> None:
    """Ensures forward output size matches input except for out_features."""
    # create new FullyConnected that accepts in_features sized input
    kwargs["in_features"] = tensor.size()[-1]
    fully_connected = FullyConnected(**kwargs)

    max_seq_len, batch_size, *_ = tensor.size()
    in_seq_lens = torch.randint(
        low=1,
        high=max_seq_len + 1,
        size=[batch_size],
        dtype=torch.int32,
        requires_grad=False,
    )
    out, _ = fully_connected((tensor, in_seq_lens))

    in_size = tensor.size()
    out_size = out.size()

    assert len(in_size) == len(out_size)
    assert in_size[:-1] == out_size[:-1]
    assert out_size[-1] == kwargs["out_features"]
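# The test above shows FullyConnected's forward contract: it takes a
# (tensor, seq_lens) tuple and returns one, changing only the final
# feature dimension. A hedged, concrete sketch of that contract:
x = torch.empty([100, 8, 32])                       # (max_seq_len, batch, in_features)
seq_lens = torch.full([8], 100, dtype=torch.int32)  # one length per batch element
fc = FullyConnected(
    in_features=32,
    out_features=16,
    num_hidden_layers=1,
    hidden_size=64,
    hidden_activation_fn=torch.nn.ReLU(),
)
out, out_seq_lens = fc((x, seq_lens))
assert out.size() == torch.Size([100, 8, 16])       # only the last dim changes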
def fully_connected_module_match_cfg(
    fully_connected: FullyConnected,
    fully_connected_cfg: fully_connected_pb2.FullyConnected,
    input_features: int,
    output_features: int,
) -> None:
    """Ensures ``FullyConnected`` module matches protobuf configuration."""
    fully_connected = fully_connected.fully_connected  # get torch module

    # if there are no hidden layers, the module should be a single Linear
    # with the correct sizes; activation is ignored
    if fully_connected_cfg.num_hidden_layers == 0:
        assert isinstance(fully_connected, torch.nn.Linear)
        assert fully_connected.in_features == input_features
        assert fully_connected.out_features == output_features
        assert not fully_connected_cfg.HasField("dropout")
        return

    # otherwise it will be a Sequential of layers
    assert isinstance(fully_connected, torch.nn.Sequential)

    # expected configuration of each layer in Sequential depends on whether
    # both/either of {activation, dropout} are present.
    act_fn_is_none = fully_connected_cfg.activation.HasField("identity")
    dropout_is_none = not fully_connected_cfg.HasField("dropout")
    dropout_is_none = dropout_is_none or fully_connected_cfg.dropout.value == 0
    if act_fn_is_none:
        expected_len = fully_connected_cfg.num_hidden_layers + 1
    else:
        expected_len = 2 * fully_connected_cfg.num_hidden_layers + 1

    if not dropout_is_none:
        expected_len += fully_connected_cfg.num_hidden_layers

    assert len(fully_connected) == expected_len

    # Now check that the linear/activation_fn/dropout layers appear in the
    # expected order. We set the ``module_idx`` and then check for the
    # following condition:
    # if module_idx % total_types == <module_type>_idx:
    #     assert isinstance(module, <module_type>)
    linear_idx = 0  # in all cases
    activation_idx = -1  # infeasible value as default
    dropout_idx = -1
    if act_fn_is_none and dropout_is_none:
        total_types = 1  # (linear layers only)
    elif not act_fn_is_none and dropout_is_none:
        total_types = 2  # (linear and activation)
        activation_idx = 1
    elif act_fn_is_none and not dropout_is_none:
        total_types = 2  # (linear and dropout)
        dropout_idx = 1
    else:  # both activation and dropout present
        total_types = 3
        activation_idx = 1
        dropout_idx = 2

    for module_idx, module in enumerate(fully_connected):
        if module_idx % total_types == linear_idx:
            assert isinstance(module, torch.nn.Linear)
            assert module.in_features == input_features
            if module_idx == len(fully_connected) - 1:
                assert module.out_features == output_features
            else:
                assert module.out_features == fully_connected_cfg.hidden_size
            input_features = fully_connected_cfg.hidden_size
        elif module_idx % total_types == activation_idx:
            activation_match_cfg(module, fully_connected_cfg.activation)
        elif module_idx % total_types == dropout_idx:
            assert isinstance(module, torch.nn.Dropout)
            assert abs(module.p - fully_connected_cfg.dropout.value) < 1e-8
        else:
            raise ValueError("Check module_idx and total_types assignment. It "
                             "**should not** be possible to hit this branch!")