Example #1

import unittest

import torch
import torch.nn as nn
from torch.nn import LSTM
from torch.testing import assert_allclose

# DPLSTM is assumed to be importable from opacus.layers; adjust if your version differs.
from opacus.layers import DPLSTM

class SimpleDPLSTMTest(unittest.TestCase):
    def setUp(self):
        self.SEQ_LENGTH = 20
        self.INPUT_DIM = 25
        self.MINIBATCH_SIZE = 30
        self.LSTM_OUT_DIM = 12
        self.NUM_LAYERS = 1
        self.bidirectional = False
        self.batch_first = False

        self.num_directions = 2 if self.bidirectional else 1
        self.h_init = torch.randn(
            self.NUM_LAYERS * self.num_directions,
            self.MINIBATCH_SIZE,
            self.LSTM_OUT_DIM,
        )
        self.c_init = torch.randn(
            self.NUM_LAYERS * self.num_directions,
            self.MINIBATCH_SIZE,
            self.LSTM_OUT_DIM,
        )

        self.original_lstm = LSTM(
            self.INPUT_DIM,
            self.LSTM_OUT_DIM,
            batch_first=self.batch_first,
            num_layers=self.NUM_LAYERS,
            bidirectional=self.bidirectional,
        )
        self.dp_lstm = DPLSTM(
            self.INPUT_DIM,
            self.LSTM_OUT_DIM,
            batch_first=self.batch_first,
            num_layers=self.NUM_LAYERS,
            bidirectional=self.bidirectional,
        )

        self.dp_lstm.load_state_dict(self.original_lstm.state_dict())

    def _reset_seeds(self):
        torch.manual_seed(1337)
        torch.cuda.manual_seed(1337)

    def test_lstm_forward(self):
        x = (
            torch.randn(self.MINIBATCH_SIZE, self.SEQ_LENGTH, self.INPUT_DIM)
            if self.batch_first
            else torch.randn(self.SEQ_LENGTH, self.MINIBATCH_SIZE, self.INPUT_DIM)
        )
        hidden = (self.h_init, self.c_init)

        out, (hn, cn) = self.original_lstm(x, hidden)
        dp_out, (dp_hn, dp_cn) = self.dp_lstm(x, hidden)

        outputs_to_test = [
            (out, dp_out, "LSTM and DPLSTM output"),
            (hn, dp_hn, "LSTM and DPLSTM state `h`"),
            (cn, dp_cn, "LSTM and DPLSTM state `c`"),
        ]

        for output, dp_output, message in outputs_to_test:
            assert_allclose(
                actual=dp_output.expand_as(output),
                expected=output,
                atol=10e-6,
                rtol=10e-5,
                msg=f"Tensor value mismatch between {message}",
            )

    def test_lstm_backward(self):
        x = (
            torch.randn(self.MINIBATCH_SIZE, self.SEQ_LENGTH, self.INPUT_DIM)
            if self.batch_first
            else torch.randn(self.SEQ_LENGTH, self.MINIBATCH_SIZE, self.INPUT_DIM)
        )
        criterion = nn.MSELoss()

        hidden = (self.h_init, self.c_init)

        out, (hn, cn) = self.original_lstm(x, hidden)
        y = torch.zeros_like(out)
        loss = criterion(out, y)
        loss.backward()

        dp_out, (dp_hn, dp_cn) = self.dp_lstm(x, hidden)
        dp_loss = criterion(dp_out, y)
        dp_loss.backward()

        dp_lstm_params = dict(self.dp_lstm.named_parameters())
        for param_name, param in self.original_lstm.named_parameters():
            dp_param = dp_lstm_params[param_name]
            assert_allclose(
                actual=dp_param,
                expected=param,
                atol=10e-5,
                rtol=10e-3,
                msg=f"Tensor value mismatch in the parameter '{param_name}'",
            )
            assert_allclose(
                actual=dp_param.grad,
                expected=param.grad,
                atol=10e-6,
                rtol=10e-5,
                msg=f"Tensor value mismatch in the gradient of parameter '{param_name}'",
            )

    def test_lstm_param_update(self):
        x = (
            torch.randn(self.MINIBATCH_SIZE, self.SEQ_LENGTH, self.INPUT_DIM)
            if self.batch_first
            else torch.randn(self.SEQ_LENGTH, self.MINIBATCH_SIZE, self.INPUT_DIM)
        )
        criterion = nn.MSELoss()

        optimizer = torch.optim.SGD(self.original_lstm.parameters(), lr=0.5)
        dp_optimizer = torch.optim.SGD(self.dp_lstm.parameters(), lr=0.5)

        # Train original LSTM for one step
        logits, (h_n, c_n) = self.original_lstm(x)
        y = torch.zeros_like(logits)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        # Train DP LSTM for one step
        dp_logits, (dp_h_n, dp_c_n) = self.dp_lstm(x)
        dp_loss = criterion(dp_logits, y)
        dp_loss.backward()
        dp_optimizer.step()

        dp_lstm_params = dict(self.dp_lstm.named_parameters())
        for param_name, param in self.original_lstm.named_parameters():
            dp_param = dp_lstm_params[param_name]
            assert_allclose(
                actual=dp_param,
                expected=param,
                atol=10e-6,
                rtol=10e-5,
                msg=f"Tensor value mismatch in the parameter '{param_name}'",
            )
            assert_allclose(
                actual=dp_param.grad,
                expected=param.grad,
                atol=10e-6,
                rtol=10e-5,
                msg=f"Tensor value mismatch in the gradient of parameter '{param_name}'",
            )
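
All three checks above rely on the same drop-in pattern: copy the nn.LSTM weights into the DPLSTM with load_state_dict, then compare forward outputs, gradients, and post-SGD parameters. Below is a minimal standalone sketch of that pattern; the opacus.layers import path and the tolerances are assumptions carried over from the code above, not a prescribed API.

# Minimal sketch of the drop-in comparison the test class above automates.
import torch
from torch import nn
from opacus.layers import DPLSTM  # assumed import path, as in Example #1

lstm = nn.LSTM(25, 12)                       # input_dim=25, hidden_dim=12
dp_lstm = DPLSTM(25, 12)
dp_lstm.load_state_dict(lstm.state_dict())   # copy weights so outputs match

x = torch.randn(20, 30, 25)                  # (seq_len, batch, input_dim)
out, (hn, cn) = lstm(x)
dp_out, (dp_hn, dp_cn) = dp_lstm(x)
print(torch.allclose(out, dp_out, atol=1e-5, rtol=1e-4))  # expected: True
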
Example #2

The following snippet is a single test method taken from a larger test class. It assumes the same imports as Example #1; _gen_packed_data, lstm_train_fn, compare_forward_outputs, and compare_gradients are helpers defined elsewhere in that test module.

    def test_lstm(
        self,
        batch_size: int,
        seq_len: int,
        emb_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool,
        bias: bool,
        batch_first: bool,
        zero_init: bool,
        packed_input_flag: int,
    ):
        lstm = nn.LSTM(
            emb_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            bidirectional=bidirectional,
            bias=bias,
        )
        dp_lstm = DPLSTM(
            emb_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            bidirectional=bidirectional,
            bias=bias,
        )

        dp_lstm.load_state_dict(lstm.state_dict())

        if packed_input_flag == 0:
            x = (torch.randn([batch_size, seq_len, emb_size]) if batch_first
                 else torch.randn([seq_len, batch_size, emb_size]))
        elif packed_input_flag == 1:
            x = _gen_packed_data(batch_size,
                                 seq_len,
                                 emb_size,
                                 batch_first,
                                 sorted_=True)
        elif packed_input_flag == 2:
            x = _gen_packed_data(batch_size,
                                 seq_len,
                                 emb_size,
                                 batch_first,
                                 sorted_=False)

        if zero_init:
            self.compare_forward_outputs(
                lstm,
                dp_lstm,
                x,
                output_names=("out", "hn", "cn"),
                atol=1e-5,
                rtol=1e-3,
            )

            self.compare_gradients(
                lstm,
                dp_lstm,
                lstm_train_fn,
                x,
                atol=1e-5,
                rtol=1e-3,
            )

        else:
            num_directions = 2 if bidirectional else 1
            h0 = torch.randn(
                [num_layers * num_directions, batch_size, hidden_size])
            c0 = torch.randn(
                [num_layers * num_directions, batch_size, hidden_size])
            self.compare_forward_outputs(
                lstm,
                dp_lstm,
                x,
                (h0, c0),
                output_names=("out", "hn", "cn"),
                atol=1e-5,
                rtol=1e-3,
            )
            self.compare_gradients(
                lstm,
                dp_lstm,
                lstm_train_fn,
                x,
                (h0, c0),
                atol=1e-5,
                rtol=1e-3,
            )
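
The snippet above never defines _gen_packed_data. The sketch below is a hypothetical reconstruction of such a helper, shown only to make the packed-input branches concrete: it builds a PackedSequence from random per-sequence lengths with torch.nn.utils.rnn.pack_padded_sequence. The actual helper in the original test module may differ.

# Hypothetical reconstruction of the _gen_packed_data helper used above;
# the real helper is defined elsewhere in the original test module.
import torch
from torch.nn.utils.rnn import pack_padded_sequence


def _gen_packed_data(batch_size, max_seq_len, emb_size, batch_first, sorted_):
    # Draw a random length in [1, max_seq_len] for every sequence in the batch.
    lengths = torch.randint(low=1, high=max_seq_len + 1, size=(batch_size,))
    if sorted_:
        # enforce_sorted=True requires lengths in decreasing order.
        lengths, _ = torch.sort(lengths, descending=True)

    # Random padded batch in the layout the LSTM expects.
    if batch_first:
        padded = torch.randn(batch_size, max_seq_len, emb_size)
    else:
        padded = torch.randn(max_seq_len, batch_size, emb_size)

    return pack_padded_sequence(
        padded, lengths, batch_first=batch_first, enforce_sorted=sorted_
    )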