class SimpleDPLSTMTest(unittest.TestCase):
    """Compare ``DPLSTM`` against ``LSTM`` for one fixed configuration.

    ``setUp`` builds an ``LSTM`` and a ``DPLSTM`` with identical constructor
    arguments and copies the ``LSTM`` state dict into the ``DPLSTM``, so both
    modules start from the same weights. The tests then check that forward
    outputs, gradients and SGD-updated parameters agree within tolerance.

    Note on tolerances: literals such as ``10e-6`` equal ``1e-5``.
    """

    def setUp(self):
        # Seed before anything random is drawn (initial states, module
        # weights, and the inputs each test creates) so that a failure can be
        # reproduced exactly.
        self._reset_seeds()

        self.SEQ_LENGTH = 20
        self.INPUT_DIM = 25
        self.MINIBATCH_SIZE = 30
        self.LSTM_OUT_DIM = 12
        self.NUM_LAYERS = 1
        self.bidirectional = False
        self.batch_first = False

        self.num_directions = 2 if self.bidirectional else 1

        state_shape = (
            self.NUM_LAYERS * self.num_directions,
            self.MINIBATCH_SIZE,
            self.LSTM_OUT_DIM,
        )
        self.h_init = torch.randn(*state_shape)
        self.c_init = torch.randn(*state_shape)

        self.original_lstm = LSTM(
            self.INPUT_DIM,
            self.LSTM_OUT_DIM,
            batch_first=self.batch_first,
            num_layers=self.NUM_LAYERS,
            bidirectional=self.bidirectional,
        )
        self.dp_lstm = DPLSTM(
            self.INPUT_DIM,
            self.LSTM_OUT_DIM,
            batch_first=self.batch_first,
            num_layers=self.NUM_LAYERS,
            bidirectional=self.bidirectional,
        )

        # Both modules must start from identical weights for the comparisons
        # below to be meaningful.
        self.dp_lstm.load_state_dict(self.original_lstm.state_dict())

    def _reset_seeds(self):
        """Seed the CPU and CUDA random number generators with a fixed value."""
        torch.manual_seed(1337)
        torch.cuda.manual_seed(1337)

    def _make_input(self):
        """Return a random input tensor laid out according to ``batch_first``."""
        if self.batch_first:
            return torch.randn(
                self.MINIBATCH_SIZE, self.SEQ_LENGTH, self.INPUT_DIM
            )
        return torch.randn(self.SEQ_LENGTH, self.MINIBATCH_SIZE, self.INPUT_DIM)

    def _assert_params_and_grads_match(self, param_atol, param_rtol):
        """Assert that parameters and gradients of both modules agree.

        Args:
            param_atol: Absolute tolerance for the parameter comparison.
            param_rtol: Relative tolerance for the parameter comparison.

        Gradients are always compared with ``atol=10e-6`` and ``rtol=10e-5``.
        """
        dp_lstm_params = dict(self.dp_lstm.named_parameters())
        for param_name, param in self.original_lstm.named_parameters():
            dp_param = dp_lstm_params[param_name]

            assert_allclose(
                actual=dp_param,
                expected=param,
                atol=param_atol,
                rtol=param_rtol,
                msg=f"Tensor value mismatch in the parameter '{param_name}'",
            )
            assert_allclose(
                actual=dp_param.grad,
                expected=param.grad,
                atol=10e-6,
                rtol=10e-5,
                msg=f"Tensor value mismatch in the gradient of parameter '{param_name}'",
            )

    def test_lstm_forward(self):
        """Forward outputs and final states of both modules must agree."""
        x = self._make_input()
        hidden = (self.h_init, self.c_init)

        out, (hn, cn) = self.original_lstm(x, hidden)
        dp_out, (dp_hn, dp_cn) = self.dp_lstm(x, hidden)

        outputs_to_test = [
            (out, dp_out, "LSTM and DPLSTM output"),
            (hn, dp_hn, "LSTM and DPLSTM state `h`"),
            (cn, dp_cn, "LSTM and DPLSTM state `c`"),
        ]

        for output, dp_output, message in outputs_to_test:
            assert_allclose(
                actual=dp_output.expand_as(output),
                expected=output,
                atol=10e-6,
                rtol=10e-5,
                msg=f"Tensor value mismatch between {message}",
            )

    def test_lstm_backward(self):
        """Gradients of an MSE loss must agree between both modules."""
        x = self._make_input()
        criterion = nn.MSELoss()
        hidden = (self.h_init, self.c_init)

        # Only the output sequence feeds the loss; the final states are unused.
        out, _ = self.original_lstm(x, hidden)
        y = torch.zeros_like(out)
        loss = criterion(out, y)
        loss.backward()

        dp_out, _ = self.dp_lstm(x, hidden)
        dp_loss = criterion(dp_out, y)
        dp_loss.backward()

        self._assert_params_and_grads_match(param_atol=10e-5, param_rtol=10e-3)

    def test_lstm_param_update(self):
        """Parameters must still agree after one SGD step on each module."""
        x = self._make_input()
        criterion = nn.MSELoss()
        optimizer = torch.optim.SGD(self.original_lstm.parameters(), lr=0.5)
        dp_optimizer = torch.optim.SGD(self.dp_lstm.parameters(), lr=0.5)

        # Train original LSTM for one step
        logits, _ = self.original_lstm(x)
        y = torch.zeros_like(logits)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        # Train DP LSTM for one step
        dp_logits, _ = self.dp_lstm(x)
        dp_loss = criterion(dp_logits, y)
        dp_loss.backward()
        dp_optimizer.step()

        self._assert_params_and_grads_match(param_atol=10e-6, param_rtol=10e-5)
def test_lstm(
    self,
    batch_size: int,
    seq_len: int,
    emb_size: int,
    hidden_size: int,
    num_layers: int,
    bidirectional: bool,
    bias: bool,
    batch_first: bool,
    zero_init: bool,
    packed_input_flag: int,
):
    """Compare ``DPLSTM`` with ``nn.LSTM`` on forward outputs and gradients.

    Both modules are built with the same constructor arguments and the
    ``nn.LSTM`` state dict is loaded into the ``DPLSTM`` so they share weights.

    Args:
        batch_size: Batch dimension of the generated input.
        seq_len: Sequence-length dimension of the generated input.
        emb_size: Input feature size of both modules.
        hidden_size: Hidden size of both modules.
        num_layers: Number of stacked layers.
        bidirectional: Whether both modules are bidirectional.
        bias: Whether both modules use bias terms.
        batch_first: Whether the batch dimension comes first in the input.
        zero_init: If true, no initial state is passed to the modules;
            otherwise random ``(h0, c0)`` tensors are passed.
        packed_input_flag: Selects the input. ``0`` is a plain random tensor,
            ``1`` is ``_gen_packed_data(..., sorted_=True)`` and ``2`` is
            ``_gen_packed_data(..., sorted_=False)``.

    Raises:
        ValueError: If ``packed_input_flag`` is not 0, 1 or 2.
    """
    lstm = nn.LSTM(
        emb_size,
        hidden_size,
        num_layers=num_layers,
        batch_first=batch_first,
        bidirectional=bidirectional,
        bias=bias,
    )
    dp_lstm = DPLSTM(
        emb_size,
        hidden_size,
        num_layers=num_layers,
        batch_first=batch_first,
        bidirectional=bidirectional,
        bias=bias,
    )
    dp_lstm.load_state_dict(lstm.state_dict())

    if packed_input_flag == 0:
        if batch_first:
            x = torch.randn([batch_size, seq_len, emb_size])
        else:
            x = torch.randn([seq_len, batch_size, emb_size])
    elif packed_input_flag == 1:
        x = _gen_packed_data(
            batch_size, seq_len, emb_size, batch_first, sorted_=True
        )
    elif packed_input_flag == 2:
        x = _gen_packed_data(
            batch_size, seq_len, emb_size, batch_first, sorted_=False
        )
    else:
        # Without this branch ``x`` would stay unbound and the test would
        # fail later with an unrelated-looking UnboundLocalError.
        raise ValueError(
            f"packed_input_flag must be 0, 1 or 2, got {packed_input_flag!r}"
        )

    # Extra positional arguments forwarded after ``x``: nothing when no
    # initial state is passed, otherwise a single ``(h0, c0)`` tuple.
    state_args = ()
    if not zero_init:
        num_directions = 2 if bidirectional else 1
        state_shape = [num_layers * num_directions, batch_size, hidden_size]
        h0 = torch.randn(state_shape)
        c0 = torch.randn(state_shape)
        state_args = ((h0, c0),)

    self.compare_forward_outputs(
        lstm,
        dp_lstm,
        x,
        *state_args,
        output_names=("out", "hn", "cn"),
        atol=1e-5,
        rtol=1e-3,
    )
    self.compare_gradients(
        lstm,
        dp_lstm,
        lstm_train_fn,
        x,
        *state_args,
        atol=1e-5,
        rtol=1e-3,
    )