def test_rnn(
    self,
    model,
    N: int,
    T: int,
    D: int,
    H: int,
    num_layers: int,
    bias: bool,
    batch_first: bool,
    bidirectional: bool,
    using_packed_sequences: bool,
    packed_sequences_sorted: bool,
):
    rnn = model(
        D,
        H,
        num_layers=num_layers,
        batch_first=batch_first,
        bias=bias,
        bidirectional=bidirectional,
    )
    rnn = DPRNNAdapter(rnn)

    if using_packed_sequences:
        x = _gen_packed_data(N, T, D, batch_first, packed_sequences_sorted)
    else:
        if batch_first:
            x = torch.randn([N, T, D])
        else:
            x = torch.randn([T, N, D])

    self.run_test(x, rnn, batch_first=batch_first)
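# A minimal sketch of the packed-sequence helper these tests rely on, assuming
# it draws one random length per sample (at most T) and packs the padded batch.
# The name `_gen_packed_data_sketch` and its body are illustrative only; the
# actual `_gen_packed_data` used by the suite may be implemented differently.
import torch
from torch.nn.utils.rnn import pack_padded_sequence


def _gen_packed_data_sketch(N, T, D, batch_first, sorted_):
    # One random length in [1, T] per batch element.
    lengths = torch.randint(low=1, high=T + 1, size=(N,))
    if sorted_:
        # With enforce_sorted=True, pack_padded_sequence expects lengths
        # in descending order.
        lengths, _ = torch.sort(lengths, descending=True)
    x = torch.randn([N, T, D]) if batch_first else torch.randn([T, N, D])
    return pack_padded_sequence(
        x, lengths, batch_first=batch_first, enforce_sorted=sorted_
    )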
def test_lstm(
    self,
    N: int,
    T: int,
    D: int,
    H: int,
    num_layers: int,
    bias: bool,
    batch_first: bool,
    bidirectional: bool,
    using_packed_sequences: bool,
    packed_sequences_sorted: bool,
):
    lstm = DPSLTMAdapter(
        D,
        H,
        num_layers=num_layers,
        batch_first=batch_first,
        bias=bias,
        bidirectional=bidirectional,
    )

    if using_packed_sequences:
        x = _gen_packed_data(N, T, D, batch_first, packed_sequences_sorted)
    else:
        if batch_first:
            x = torch.randn([N, T, D])
        else:
            x = torch.randn([T, N, D])

    self.run_test(x, lstm, batch_first=batch_first)
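# For context: the adapter classes above are assumed to wrap the DP module so
# that it returns a single tensor, which is what `run_test` needs for its
# per-sample-gradient checks. A sketch of what such an adapter might look like
# (the real DPSLTMAdapter / DPRNNAdapter may differ):
from torch import nn
from opacus.layers import DPLSTM


class DPLSTMAdapterSketch(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.dplstm = DPLSTM(*args, **kwargs)

    def forward(self, x):
        # DPLSTM returns (output, (h_n, c_n)); keep only the output tensor.
        out, _ = self.dplstm(x)
        return out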
def test_lstm(
    self,
    batch_size: int,
    seq_len: int,
    emb_size: int,
    hidden_size: int,
    num_layers: int,
    bidirectional: bool,
    bias: bool,
    batch_first: bool,
    zero_init: bool,
    packed_input_flag: int,
):
    lstm = nn.LSTM(
        emb_size,
        hidden_size,
        num_layers=num_layers,
        batch_first=batch_first,
        bidirectional=bidirectional,
        bias=bias,
    )
    dp_lstm = DPLSTM(
        emb_size,
        hidden_size,
        num_layers=num_layers,
        batch_first=batch_first,
        bidirectional=bidirectional,
        bias=bias,
    )
    # Both modules must start from identical weights for the comparison.
    dp_lstm.load_state_dict(lstm.state_dict())

    if packed_input_flag == 0:
        x = (
            torch.randn([batch_size, seq_len, emb_size])
            if batch_first
            else torch.randn([seq_len, batch_size, emb_size])
        )
    elif packed_input_flag == 1:
        x = _gen_packed_data(batch_size, seq_len, emb_size, batch_first, sorted_=True)
    elif packed_input_flag == 2:
        x = _gen_packed_data(batch_size, seq_len, emb_size, batch_first, sorted_=False)
    else:
        # Guard against `x` being unbound for an unexpected flag value.
        raise ValueError("Invalid packed input flag")

    if zero_init:
        self.compare_forward_outputs(
            lstm,
            dp_lstm,
            x,
            output_names=("out", "hn", "cn"),
            atol=1e-5,
            rtol=1e-3,
        )
        self.compare_gradients(
            lstm,
            dp_lstm,
            lstm_train_fn,
            x,
            atol=1e-5,
            rtol=1e-3,
        )
    else:
        num_directions = 2 if bidirectional else 1
        h0 = torch.randn([num_layers * num_directions, batch_size, hidden_size])
        c0 = torch.randn([num_layers * num_directions, batch_size, hidden_size])
        self.compare_forward_outputs(
            lstm,
            dp_lstm,
            x,
            (h0, c0),
            output_names=("out", "hn", "cn"),
            atol=1e-5,
            rtol=1e-3,
        )
        self.compare_gradients(
            lstm,
            dp_lstm,
            lstm_train_fn,
            x,
            (h0, c0),
            atol=1e-5,
            rtol=1e-3,
        )
def test_rnn(
    self,
    mode: str,
    batch_size: int,
    seq_len: int,
    emb_size: int,
    hidden_size: int,
    num_layers: int,
    bidirectional: bool,
    bias: bool,
    batch_first: bool,
    zero_init: bool,
    packed_input_flag: int,
):
    use_cn = False
    if mode == "rnn":
        original_rnn_class = nn.RNN
        dp_rnn_class = DPRNN
    elif mode == "gru":
        original_rnn_class = nn.GRU
        dp_rnn_class = DPGRU
    elif mode == "lstm":
        original_rnn_class = nn.LSTM
        dp_rnn_class = DPLSTM
        use_cn = True
    else:
        raise ValueError("Invalid RNN mode")

    rnn = original_rnn_class(
        emb_size,
        hidden_size,
        num_layers=num_layers,
        batch_first=batch_first,
        bidirectional=bidirectional,
        bias=bias,
    )
    dp_rnn = dp_rnn_class(
        emb_size,
        hidden_size,
        num_layers=num_layers,
        batch_first=batch_first,
        bidirectional=bidirectional,
        bias=bias,
    )
    dp_rnn.load_state_dict(rnn.state_dict())

    if packed_input_flag == 0:
        x = (
            torch.randn([batch_size, seq_len, emb_size])
            if batch_first
            else torch.randn([seq_len, batch_size, emb_size])
        )
    elif packed_input_flag == 1:
        x = _gen_packed_data(batch_size, seq_len, emb_size, batch_first, sorted_=True)
    elif packed_input_flag == 2:
        x = _gen_packed_data(batch_size, seq_len, emb_size, batch_first, sorted_=False)
    else:
        raise ValueError("Invalid packed input flag")

    if zero_init:
        self.compare_forward_outputs(
            rnn,
            dp_rnn,
            x,
            output_names=("out", "hn", "cn") if use_cn else ("out", "hn"),
            atol=1e-5,
            rtol=1e-3,
        )
        self.compare_gradients(
            rnn,
            dp_rnn,
            rnn_train_fn,
            x,
            atol=1e-5,
            rtol=1e-3,
        )
    else:
        num_directions = 2 if bidirectional else 1
        h0 = torch.randn([num_layers * num_directions, batch_size, hidden_size])
        c0 = torch.randn([num_layers * num_directions, batch_size, hidden_size])
        self.compare_forward_outputs(
            rnn,
            dp_rnn,
            x,
            (h0, c0) if use_cn else h0,
            output_names=("out", "hn", "cn") if use_cn else ("out", "hn"),
            atol=1e-5,
            rtol=1e-3,
        )
        self.compare_gradients(
            rnn,
            dp_rnn,
            rnn_train_fn,
            x,
            (h0, c0) if use_cn else h0,
            atol=1e-5,
            rtol=1e-3,
        )
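# A hypothetical standalone version of the forward-output comparison performed
# by `compare_forward_outputs` above, under the assumption that it runs both
# modules on the same input and checks each named output within (atol, rtol).
# The function name and default sizes here are illustrative, not part of the
# test suite.
import torch
from torch import nn
from opacus.layers import DPLSTM


def check_lstm_forward_equivalence(emb_size=8, hidden_size=6, seq_len=5, batch=4):
    lstm = nn.LSTM(emb_size, hidden_size, batch_first=True)
    dp_lstm = DPLSTM(emb_size, hidden_size, batch_first=True)
    dp_lstm.load_state_dict(lstm.state_dict())  # identical weights

    x = torch.randn([batch, seq_len, emb_size])
    out, (hn, cn) = lstm(x)
    dp_out, (dp_hn, dp_cn) = dp_lstm(x)

    for name, a, b in (("out", out, dp_out), ("hn", hn, dp_hn), ("cn", cn, dp_cn)):
        assert torch.allclose(a, b, atol=1e-5, rtol=1e-3), f"{name} mismatch"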