def fully_connecteds(
    draw, return_kwargs: bool = False
) -> Union[
    st.SearchStrategy[FullyConnected],
    st.SearchStrategy[Tuple[FullyConnected, Dict]],
]:
    """Returns a SearchStrategy for FullyConnected.

    The constructor arguments are drawn from ``fully_connected_kwargs()``.
    When ``return_kwargs`` is true, the drawn kwargs dict is returned
    alongside the module as a ``(module, kwargs)`` tuple.
    """
    drawn_kwargs = draw(fully_connected_kwargs())
    module = FullyConnected(**drawn_kwargs)
    return (module, drawn_kwargs) if return_kwargs else module
def test_fully_connected_raises_value_error_negative_num_hidden_layers(
    kwargs: Dict, num_hidden_layers: int
) -> None:
    """Checks that FullyConnected rejects num_hidden_layers < 0 with ValueError.

    The negative ``num_hidden_layers`` value is injected into otherwise-valid
    constructor ``kwargs`` before construction is attempted.
    """
    kwargs.update(num_hidden_layers=num_hidden_layers)
    with pytest.raises(ValueError):
        FullyConnected(**kwargs)
def test_fully_connected_raises_value_error_hidden_activation_fn_not_none(
    kwargs: Dict,
    hidden_activation_fn: torch.nn.Module,
) -> None:
    """Ensures ValueError raised when no hidden layers but an act fn is given.

    With ``num_hidden_layers=0`` there are no hidden layers for
    ``hidden_activation_fn`` to be applied to, so passing a (non-None)
    activation module is expected to be rejected.
    """
    kwargs["num_hidden_layers"] = 0
    kwargs["hidden_activation_fn"] = hidden_activation_fn
    with pytest.raises(ValueError):
        FullyConnected(**kwargs)
def test_fully_connected_raises_value_error_hidden_size_not_none(
    kwargs: Dict, hidden_size: int
) -> None:
    """Ensures ValueError raised when no hidden layers but hidden_size is set.

    With ``num_hidden_layers=0`` a ``hidden_size`` is meaningless, so any
    non-None ``hidden_size`` is expected to be rejected.
    """
    kwargs["num_hidden_layers"] = 0
    kwargs["hidden_size"] = hidden_size
    with pytest.raises(ValueError):
        FullyConnected(**kwargs)
def test_fully_connected_raises_value_error_dropout_negative_or_greater_than_1(
    kwargs: Dict, dropout: float
) -> None:
    """Checks FullyConnected rejects a dropout probability < 0 or > 1.

    Examples with zero hidden layers are discarded: with no hidden layers
    *any* non-None dropout is rejected, so the ValueError would not prove
    that the out-of-range value itself was detected.
    """
    kwargs.update(dropout=dropout)
    assume(kwargs["num_hidden_layers"] != 0)
    with pytest.raises(ValueError):
        FullyConnected(**kwargs)
def test_fully_connected_raises_value_error_zero_hidden_dropout_not_None(
    kwargs: Dict, dropout: float
) -> None:
    """Ensures ValueError raised when num_hidden_layers=0 and dropout is not None.

    Without hidden layers there is nowhere to insert dropout, so a non-None
    ``dropout`` is expected to be rejected regardless of its value.
    """
    kwargs["num_hidden_layers"] = 0
    kwargs["dropout"] = dropout
    with pytest.raises(ValueError):
        FullyConnected(**kwargs)
def build(
    fully_connected_cfg: fully_connected_pb2.FullyConnected,
    input_features: int,
    output_features: int,
) -> FullyConnected:
    """Returns a :py:class:`.FullyConnected` based on the config.

    Config values are mapped onto :py:class:`.FullyConnected` arguments as
    follows: an activation that builds to :py:class:`torch.nn.Identity`
    becomes ``hidden_activation_fn=None``; a ``hidden_size`` of ``0``
    becomes ``None``; an unset ``dropout`` field, or one whose value is
    ``0``, becomes ``dropout=None``.

    Args:
        fully_connected_cfg: A ``FullyConnected`` protobuf object containing
            the config for the desired :py:class:`torch.nn.Module`.

        input_features: The number of features for the input.

        output_features: The number of output features.

    Returns:
        A :py:class:`torch.nn.Module` based on the config.

    Example:
        >>> from google.protobuf import text_format
        >>> cfg_text = '''
        ... num_hidden_layers: 2;
        ... hidden_size: 64;
        ... activation {
        ...   relu { }
        ... }
        ... '''
        >>> cfg = text_format.Merge(
        ...     cfg_text,
        ...     fully_connected_pb2.FullyConnected()
        ... )
        >>> build(cfg, input_features=32, output_features=16)
        FullyConnected(
          (fully_connected): Sequential(
            (0): Linear(in_features=32, out_features=64, bias=True)
            (1): ReLU()
            (2): Linear(in_features=64, out_features=64, bias=True)
            (3): ReLU()
            (4): Linear(in_features=64, out_features=16, bias=True)
          )
        )
    """
    activation = build_activation(fully_connected_cfg.activation)
    if isinstance(activation, torch.nn.Identity):
        activation = None

    hidden_size = None
    if fully_connected_cfg.hidden_size > 0:
        hidden_size = fully_connected_cfg.hidden_size

    # ``dropout`` is a message field (its number is read via ``.value``), so
    # presence is tested with ``HasField``. A value of 0 means "no dropout"
    # and is mapped to None: FullyConnected rejects any non-None dropout when
    # there are no hidden layers, and this keeps such configs valid.
    dropout = None
    if (
        fully_connected_cfg.HasField("dropout")
        and fully_connected_cfg.dropout.value != 0
    ):
        dropout = fully_connected_cfg.dropout.value

    return FullyConnected(
        in_features=input_features,
        out_features=output_features,
        num_hidden_layers=fully_connected_cfg.num_hidden_layers,
        hidden_size=hidden_size,
        hidden_activation_fn=activation,
        dropout=dropout,
    )
def fully_connecteds(
    draw, return_kwargs: bool = False
) -> Union[
    st.SearchStrategy[FullyConnected],
    st.SearchStrategy[Tuple[FullyConnected, Dict]],
]:
    """Returns a SearchStrategy for FullyConnected.

    Draws ``in_features`` and ``out_features`` in [1, 32] and
    ``num_hidden_layers`` in [0, 8]. ``hidden_size`` (in [1, 32]) and a
    ReLU/Tanh ``hidden_activation_fn`` are only drawn when there is at least
    one hidden layer; otherwise both are None. When ``return_kwargs`` is true
    a ``(module, kwargs)`` tuple is produced instead of just the module.
    """
    in_features = draw(st.integers(1, 32))
    out_features = draw(st.integers(1, 32))
    num_hidden_layers = draw(st.integers(0, 8))

    hidden_size = None
    hidden_activation_fn = None
    if num_hidden_layers:
        hidden_size = draw(st.integers(1, 32))
        hidden_activation_fn = draw(
            st.sampled_from([torch.nn.ReLU(), torch.nn.Tanh()])
        )

    kwargs = dict(
        in_features=in_features,
        out_features=out_features,
        num_hidden_layers=num_hidden_layers,
        hidden_size=hidden_size,
        hidden_activation_fn=hidden_activation_fn,
    )
    module = FullyConnected(**kwargs)
    return (module, kwargs) if return_kwargs else module
def test_fully_connected_forward_returns_correct_size(
    kwargs: Dict, tensor: torch.Tensor
) -> None:
    """Checks the forward output keeps all leading dims and ends in out_features.

    A new FullyConnected is built whose ``in_features`` equals the last
    dimension of ``tensor``. It is called on ``(tensor, seq_lens)`` where
    ``seq_lens`` holds one random int32 length in ``[1, size(0)]`` per entry
    of dimension 1.
    """
    # Rebuild the module so it accepts inputs with the tensor's feature size.
    kwargs["in_features"] = tensor.size()[-1]
    module = FullyConnected(**kwargs)

    # This test treats dim 0 as the sequence axis and dim 1 as the batch axis.
    seq_len_max, batch, *_rest = tensor.size()
    seq_lens = torch.randint(
        low=1,
        high=seq_len_max + 1,  # ``high`` is exclusive
        size=[batch],
        dtype=torch.int32,
        requires_grad=False,
    )

    result, _ = module((tensor, seq_lens))

    in_shape = tensor.size()
    out_shape = result.size()
    assert len(out_shape) == len(in_shape)
    assert out_shape[:-1] == in_shape[:-1]
    assert out_shape[-1] == kwargs["out_features"]
def fully_connected_module_match_cfg(
    fully_connected: FullyConnected,
    fully_connected_cfg: fully_connected_pb2.FullyConnected,
    input_features: int,
    output_features: int,
) -> None:
    """Ensures ``FullyConnected`` module matches protobuf configuration.

    Args:
        fully_connected: The module under test; its ``fully_connected``
            attribute holds the underlying torch module.
        fully_connected_cfg: The protobuf config the module should match.
        input_features: Expected ``in_features`` of the first Linear layer.
        output_features: Expected ``out_features`` of the last Linear layer.

    Raises:
        AssertionError: If the module does not match the config.
        ValueError: If the layer-type bookkeeping below is inconsistent
            (should be unreachable).
    """
    fully_connected = fully_connected.fully_connected  # get torch module

    # if no hidden layers then test that the module is Linear with correct
    # sizes, ignore activation
    if fully_connected_cfg.num_hidden_layers == 0:
        assert isinstance(fully_connected, torch.nn.Linear)
        assert fully_connected.in_features == input_features
        assert fully_connected.out_features == output_features
        # ``HasField`` is a protobuf method: query the config, not the torch
        # module (``torch.nn.Linear`` has no ``HasField``).
        assert not fully_connected_cfg.HasField("dropout")
        return

    # otherwise it will be a Sequential of layers
    assert isinstance(fully_connected, torch.nn.Sequential)

    # expected configuration of each layer in Sequential depends on whether
    # both/either of {activation, dropout} are present.
    act_fn_is_none = fully_connected_cfg.activation.HasField("identity")
    dropout_is_none = not fully_connected_cfg.HasField("dropout")
    dropout_is_none = dropout_is_none or fully_connected_cfg.dropout.value == 0

    # Each hidden layer contributes a repeating group of
    # Linear[, activation][, Dropout]; ``total_types`` is the group size and
    # the *_idx values are each module type's position within the group. -1
    # is an infeasible position for module types that are absent.
    linear_idx = 0  # in all cases
    activation_idx = -1
    dropout_idx = -1
    total_types = 1
    if not act_fn_is_none:
        activation_idx = total_types
        total_types += 1
    if not dropout_is_none:
        dropout_idx = total_types
        total_types += 1

    # One group per hidden layer plus the final output Linear layer.
    expected_len = total_types * fully_connected_cfg.num_hidden_layers + 1
    assert len(fully_connected) == expected_len

    # Now check that the linear/activation_fn/dropout layers appear in the
    # expected order: module ``module_idx`` must be of the type whose
    # position equals ``module_idx % total_types``.
    for module_idx, module in enumerate(fully_connected):
        position = module_idx % total_types
        if position == linear_idx:
            assert isinstance(module, torch.nn.Linear)
            assert module.in_features == input_features
            if module_idx == len(fully_connected) - 1:
                assert module.out_features == output_features
            else:
                assert module.out_features == fully_connected_cfg.hidden_size
            # the next Linear layer consumes this layer's hidden output
            input_features = fully_connected_cfg.hidden_size
        elif position == activation_idx:
            activation_match_cfg(module, fully_connected_cfg.activation)
        elif position == dropout_idx:
            assert isinstance(module, torch.nn.Dropout)
            assert abs(module.p - fully_connected_cfg.dropout.value) < 1e-8
        else:
            raise ValueError(
                "Check module_idx and total_types assignment. It "
                "**should not** be possible to hit this branch!"
            )