Example 1
    def test_parameters_are_leaf_tensors(self):
        # Checks that WeightDrop parameters are always leaf tensors.

        # Case 1: After initialization
        weight_dropped_linear = DropConnect(torch.nn.Linear(10, 10),
                                            parameter_regex="weight",
                                            dropout=0.9)
        assert all(parameter.is_leaf
                   for parameter in weight_dropped_linear.parameters())

        # Case 2: When in training mode
        weight_dropped_linear.train()
        assert all(parameter.is_leaf
                   for parameter in weight_dropped_linear.parameters())

        # Case 3: After forward
        input_tensor = torch.ones(10, dtype=torch.float32)
        weight_dropped_linear(input_tensor)
        assert all(parameter.is_leaf
                   for parameter in weight_dropped_linear.parameters())

        # Case 4: After reset()
        weight_dropped_linear.reset()
        assert all(parameter.is_leaf
                   for parameter in weight_dropped_linear.parameters())

        # Case 5: When in eval mode
        weight_dropped_linear.eval()
        assert all(parameter.is_leaf
                   for parameter in weight_dropped_linear.parameters())
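
The leaf-tensor checks above matter because torch optimizers only update leaf tensors: a wrapper that wrote the dropped, non-leaf copy back as a registered Parameter would break fine-tuning. Below is a minimal sketch of how a DropConnect-style wrapper can keep that invariant. The class name SimpleWeightDrop and its internals are illustrative assumptions, not the implementation under test; only the constructor signature (module, parameter_regex, dropout) mirrors the calls in these tests.

import re
import torch


class SimpleWeightDrop(torch.nn.Module):
    # Hypothetical sketch of a weight-dropping wrapper; not the tested DropConnect class.
    def __init__(self, module, parameter_regex, dropout):
        super().__init__()
        self._module = module
        self._dropout = dropout
        self._names = [name for name, _ in module.named_parameters()
                       if re.search(parameter_regex, name)]
        for name in self._names:
            raw = getattr(module, name)
            # Unregister the original Parameter and re-register the raw weight on the
            # wrapper under a "_raw" suffix, so it stays a leaf Parameter.
            delattr(module, name)
            self.register_parameter(name + "_raw", torch.nn.Parameter(raw.data))

    def forward(self, *inputs):
        for name in self._names:
            raw = getattr(self, name + "_raw")
            dropped = torch.nn.functional.dropout(
                raw, p=self._dropout, training=self.training)
            # The dropped copy is a plain tensor, not a Parameter, so parameters()
            # keeps returning only leaves; gradients still flow back to the raw leaf.
            setattr(self._module, name, dropped)
        return self._module(*inputs)

Wrapping torch.nn.Linear(10, 10) this way keeps every entry of parameters() a leaf across initialization, train(), forward, and eval(); the reset() hook exercised in Case 4 is specific to the tested class and omitted from this sketch.
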
Example 2
    def test_linear_outputs(self):
        # Check that weights are (probably) being dropped out properly. There's an extremely small
        # chance (p < 1e-86) that this test fails.
        input_tensor = torch.ones(10, dtype=torch.float32)
        dropped_linear = DropConnect(torch.nn.Linear(10, 10), parameter_regex="weight", dropout=0.9)

        # Check that outputs differ if module is in training mode
        dropped_linear.train()
        output_a = dropped_linear(input_tensor)
        output_b = dropped_linear(input_tensor)
        assert not torch.allclose(output_a, output_b)

        # Check that outputs are the same if module is in eval mode
        dropped_linear.eval()
        output_a = dropped_linear(input_tensor)
        output_b = dropped_linear(input_tensor)
        assert torch.allclose(output_a, output_b)
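
The train/eval asymmetry checked here comes directly from the inverted-dropout convention of torch.nn.functional.dropout: in training mode each call samples a fresh Bernoulli mask and rescales the surviving entries by 1 / (1 - p), while in eval mode the call is the identity. A quick standalone illustration on a weight-shaped tensor (this snippet does not use the DropConnect class):

import torch
import torch.nn.functional as F

weight = torch.ones(10, 10)

# Training mode: a fresh mask per call, kept entries scaled by 1 / (1 - p) = 10.
a = F.dropout(weight, p=0.9, training=True)
b = F.dropout(weight, p=0.9, training=True)
print(torch.allclose(a, b))   # almost certainly False

# Eval mode: dropout is a no-op, so repeated calls agree exactly.
c = F.dropout(weight, p=0.9, training=False)
d = F.dropout(weight, p=0.9, training=False)
print(torch.equal(c, d))      # True
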
Example 3
    def test_lstm_outputs(self):
        # Check that lstm weights are (probably) being dropped out properly. There's an extremely
        # small chance (p < 1e-86) that this test fails.
        input_tensor = torch.ones(1, 2, 10, dtype=torch.float32)  # shape: (batch, seq_length, dim)
        lstm = torch.nn.LSTM(input_size=10, hidden_size=10, batch_first=True)
        dropped_lstm = DropConnect(module=lstm, parameter_regex="weight_hh", dropout=0.9)

        # Check that outputs differ if module is in training mode. Since only hidden-to-hidden
        # weights are masked, the first outputs should be the same.
        dropped_lstm.train()
        output_a, _ = dropped_lstm(input_tensor)
        output_b, _ = dropped_lstm(input_tensor)
        assert torch.allclose(output_a[:, 0, :], output_b[:, 0, :])
        assert not torch.allclose(output_a[:, 1, :], output_b[:, 1, :])

        # Check that outputs are the same if module is in eval mode
        dropped_lstm.eval()
        output_a, _ = dropped_lstm(input_tensor)
        output_b, _ = dropped_lstm(input_tensor)
        assert torch.allclose(output_a, output_b)
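
The first-time-step equality asserted above is a consequence of the zero initial hidden state: with h_0 = 0 the hidden-to-hidden term W_hh h_0 vanishes, so a mask applied only to weight_hh cannot change the output at the first step and only shows up from the second step onward. A minimal illustration with a plain tanh recurrence, which is a deliberate simplification of the LSTM gating used only for exposition:

import torch

torch.manual_seed(0)
x = torch.ones(2, 10)                        # two time steps, dim 10
w_ih = torch.randn(10, 10)                   # input-to-hidden weights
w_hh = torch.randn(10, 10)                   # hidden-to-hidden weights
mask = (torch.rand(10, 10) > 0.9).float()    # keep ~10% of entries (dropout=0.9, no rescaling)

def run(w_hh):
    h = torch.zeros(10)                      # zero initial hidden state
    outputs = []
    for t in range(2):
        h = torch.tanh(x[t] @ w_ih.T + h @ w_hh.T)
        outputs.append(h)
    return outputs

plain = run(w_hh)
masked = run(w_hh * mask)
print(torch.equal(plain[0], masked[0]))      # True: h_0 = 0 removes the w_hh term
print(torch.allclose(plain[1], masked[1]))   # almost certainly False
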
Example 4
    def test_parameters_are_leaf_tensors(self):
        # Checks that WeightDrop parameters are always leaf tensors.

        _in_dim = 10
        _out_dim = 10
        _n_params = 2
        _n_weights = 1

        def _assert_sgd_states(sgd, has_grad=True):
            sgd_params = sgd.param_groups[0]["params"]
            assert all(parameter.is_leaf for parameter in sgd_params)
            # Check the states of the gradients
            assert all(has_grad != (parameter.grad is None)
                       for parameter in sgd_params)
            # The number of parameters should stay the same, otherwise fine-tuning won't work.
            assert len(sgd_params) == _n_params
            weights = [
                p for p in sgd_params
                if p.shape == torch.Size([_in_dim, _out_dim])
            ]
            assert len(weights) == _n_weights

        # Case 1: After initialization
        weight_dropped_linear = DropConnect(torch.nn.Linear(_in_dim, _out_dim),
                                            parameter_regex="weight",
                                            dropout=0.9)
        assert all(parameter.is_leaf
                   for parameter in weight_dropped_linear.parameters())

        # Case 2: When in training mode
        weight_dropped_linear.train()
        assert all(parameter.is_leaf
                   for parameter in weight_dropped_linear.parameters())

        # Case 3: After forward
        input_tensor = torch.ones(_in_dim, dtype=torch.float32)
        target_tensor = torch.zeros(_out_dim, dtype=torch.float32)
        sgd = torch.optim.SGD(weight_dropped_linear.parameters(), lr=0.01)
        _assert_sgd_states(sgd, has_grad=False)

        loss_fn = torch.nn.L1Loss()
        for epoch in range(2):
            sgd.zero_grad()
            output_tensor = weight_dropped_linear(input_tensor)

            # Replaced
            # `assert all(parameter.is_leaf for parameter in weight_dropped_linear.parameters())`
            # with the following `_assert_sgd_states()` call, because the weight-dropped
            # parameters are not deleted after `forward()`.
            _assert_sgd_states(sgd, has_grad=(epoch > 0))
            loss_fn(output_tensor, target_tensor).backward()
            _assert_sgd_states(sgd, has_grad=True)
            sgd.step()
            _assert_sgd_states(sgd, has_grad=True)

        # Case 4: When in eval mode
        weight_dropped_linear.eval()
        # The duplicated (non-leaf) dropped weight still shows up among the parameters.
        pre_eval_parameters = list(weight_dropped_linear.parameters())
        assert len(pre_eval_parameters) == _n_params + 1
        assert not all(parameter.is_leaf for parameter in pre_eval_parameters)
        weight_dropped_linear(input_tensor)
        # After the eval-mode forward, only the raw (leaf) weight remains registered.
        eval_parameters = list(weight_dropped_linear.parameters())
        assert len(eval_parameters) == _n_params
        assert all(parameter.is_leaf for parameter in eval_parameters)
        raw_weight = weight_dropped_linear._module_weight_raw
        weight = weight_dropped_linear._module.weight
        assert torch.equal(raw_weight.T, weight.T)
        assert torch.equal(raw_weight.data, weight.data)