Example #1
    def test_rnn(
        self,
        model,
        N: int,
        T: int,
        D: int,
        H: int,
        num_layers: int,
        bias: bool,
        batch_first: bool,
        bidirectional: bool,
        using_packed_sequences: bool,
        packed_sequences_sorted: bool,
    ):
        """Build the given RNN model with the requested config, wrap it in
        DPRNNAdapter, and check it on padded or packed random input via
        run_test."""
        rnn = model(
            D,
            H,
            num_layers=num_layers,
            batch_first=batch_first,
            bias=bias,
            bidirectional=bidirectional,
        )
        rnn = DPRNNAdapter(rnn)

        if using_packed_sequences:
            x = _gen_packed_data(N, T, D, batch_first, packed_sequences_sorted)
        else:
            x = (torch.randn([N, T, D]) if batch_first
                 else torch.randn([T, N, D]))
        self.run_test(x, rnn, batch_first=batch_first)
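The _gen_packed_data helper used by these tests is not included in this excerpt. Below is a minimal sketch of what it could look like, assuming it builds a PackedSequence from random per-sample lengths using torch.nn.utils.rnn.pack_padded_sequence; the real helper may differ in details.

import torch
from torch.nn.utils.rnn import pack_padded_sequence

def _gen_packed_data(N, T, D, batch_first, sorted_):
    # Hypothetical sketch, not the actual helper from the test module.
    # Random per-sequence lengths in [1, T].
    lengths = torch.randint(low=1, high=T + 1, size=(N,))
    if sorted_:
        # enforce_sorted=True requires lengths in decreasing order.
        lengths = lengths.sort(descending=True).values
    shape = (N, T, D) if batch_first else (T, N, D)
    padded = torch.randn(shape)
    return pack_padded_sequence(
        padded, lengths, batch_first=batch_first, enforce_sorted=sorted_
    )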
Example #2
    def test_lstm(
        self,
        N: int,
        T: int,
        D: int,
        H: int,
        num_layers: int,
        bias: bool,
        batch_first: bool,
        bidirectional: bool,
        using_packed_sequences: bool,
        packed_sequences_sorted: bool,
    ):
        """Build a DPSLTMAdapter-wrapped LSTM and check it on padded or
        packed random input via run_test."""
        lstm = DPSLTMAdapter(
            D,
            H,
            num_layers=num_layers,
            batch_first=batch_first,
            bias=bias,
            bidirectional=bidirectional,
        )
        if using_packed_sequences:
            x = _gen_packed_data(N, T, D, batch_first, packed_sequences_sorted)
        else:
            x = (torch.randn([N, T, D]) if batch_first
                 else torch.randn([T, N, D]))
        self.run_test(x, lstm, batch_first=batch_first)
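DPSLTMAdapter (the source's spelling) is likewise defined outside this excerpt. A plausible minimal version, assuming it wraps opacus.layers.DPLSTM and returns only the output tensor so that tensor-in/tensor-out test harnesses can drive it:

from torch import nn
from opacus.layers import DPLSTM

class DPSLTMAdapter(nn.Module):
    # Assumed implementation: forward() drops the (hn, cn) state and
    # returns a single tensor, which generic test tooling expects.
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.dplstm = DPLSTM(*args, **kwargs)

    def forward(self, x):
        out, _state = self.dplstm(x)
        return out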
Example #3
    def test_lstm(
        self,
        batch_size: int,
        seq_len: int,
        emb_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool,
        bias: bool,
        batch_first: bool,
        zero_init: bool,
        packed_input_flag: int,
    ):
        """Compare nn.LSTM against DPLSTM with identical weights: forward
        outputs and gradients, with zero or random initial state, on
        padded or packed input."""
        lstm = nn.LSTM(
            emb_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            bidirectional=bidirectional,
            bias=bias,
        )
        dp_lstm = DPLSTM(
            emb_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            bidirectional=bidirectional,
            bias=bias,
        )

        dp_lstm.load_state_dict(lstm.state_dict())

        if packed_input_flag == 0:
            x = (torch.randn([batch_size, seq_len, emb_size]) if batch_first
                 else torch.randn([seq_len, batch_size, emb_size]))
        elif packed_input_flag == 1:
            x = _gen_packed_data(batch_size,
                                 seq_len,
                                 emb_size,
                                 batch_first,
                                 sorted_=True)
        elif packed_input_flag == 2:
            x = _gen_packed_data(batch_size,
                                 seq_len,
                                 emb_size,
                                 batch_first,
                                 sorted_=False)
        else:
            raise ValueError("Invalid packed input flag")

        if zero_init:
            self.compare_forward_outputs(
                lstm,
                dp_lstm,
                x,
                output_names=("out", "hn", "cn"),
                atol=1e-5,
                rtol=1e-3,
            )

            self.compare_gradients(
                lstm,
                dp_lstm,
                lstm_train_fn,
                x,
                atol=1e-5,
                rtol=1e-3,
            )

        else:
            num_directions = 2 if bidirectional else 1
            h0 = torch.randn(
                [num_layers * num_directions, batch_size, hidden_size])
            c0 = torch.randn(
                [num_layers * num_directions, batch_size, hidden_size])
            self.compare_forward_outputs(
                lstm,
                dp_lstm,
                x,
                (h0, c0),
                output_names=("out", "hn", "cn"),
                atol=1e-5,
                rtol=1e-3,
            )
            self.compare_gradients(
                lstm,
                dp_lstm,
                lstm_train_fn,
                x,
                (h0, c0),
                atol=1e-5,
                rtol=1e-3,
            )
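Outside the harness, the same parity check can be reproduced in a few lines. A sketch with illustrative sizes and tolerances, assuming opacus.layers.DPLSTM mirrors the nn.LSTM interface (which the test above relies on via load_state_dict and the out/hn/cn outputs):

import torch
import torch.nn as nn
from opacus.layers import DPLSTM

lstm = nn.LSTM(8, 16, num_layers=2, batch_first=True)
dp_lstm = DPLSTM(8, 16, num_layers=2, batch_first=True)
dp_lstm.load_state_dict(lstm.state_dict())  # identical weights

x = torch.randn(4, 10, 8)  # (batch, seq, emb) since batch_first=True
out_ref, (hn_ref, cn_ref) = lstm(x)
out_dp, (hn_dp, cn_dp) = dp_lstm(x)
assert torch.allclose(out_ref, out_dp, atol=1e-5, rtol=1e-3)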
Example #4
    def test_rnn(
        self,
        mode: str,
        batch_size: int,
        seq_len: int,
        emb_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool,
        bias: bool,
        batch_first: bool,
        zero_init: bool,
        packed_input_flag: int,
    ):
        """Same comparison as the LSTM test above, generalized over
        RNN/GRU/LSTM via mode; only the LSTM variant carries a cell
        state (cn)."""
        use_cn = False
        if mode == "rnn":
            original_rnn_class = nn.RNN
            dp_rnn_class = DPRNN
        elif mode == "gru":
            original_rnn_class = nn.GRU
            dp_rnn_class = DPGRU
        elif mode == "lstm":
            original_rnn_class = nn.LSTM
            dp_rnn_class = DPLSTM
            use_cn = True
        else:
            raise ValueError("Invalid RNN mode")

        rnn = original_rnn_class(
            emb_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            bidirectional=bidirectional,
            bias=bias,
        )
        dp_rnn = dp_rnn_class(
            emb_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            bidirectional=bidirectional,
            bias=bias,
        )

        dp_rnn.load_state_dict(rnn.state_dict())

        if packed_input_flag == 0:
            x = (torch.randn([batch_size, seq_len, emb_size]) if batch_first
                 else torch.randn([seq_len, batch_size, emb_size]))
        elif packed_input_flag == 1:
            x = _gen_packed_data(batch_size,
                                 seq_len,
                                 emb_size,
                                 batch_first,
                                 sorted_=True)
        elif packed_input_flag == 2:
            x = _gen_packed_data(batch_size,
                                 seq_len,
                                 emb_size,
                                 batch_first,
                                 sorted_=False)
        else:
            raise ValueError("Invalid packed input flag")

        if zero_init:
            self.compare_forward_outputs(
                rnn,
                dp_rnn,
                x,
                output_names=("out", "hn", "cn") if use_cn else ("out", "hn"),
                atol=1e-5,
                rtol=1e-3,
            )

            self.compare_gradients(
                rnn,
                dp_rnn,
                rnn_train_fn,
                x,
                atol=1e-5,
                rtol=1e-3,
            )

        else:
            num_directions = 2 if bidirectional else 1
            h0 = torch.randn(
                [num_layers * num_directions, batch_size, hidden_size])
            c0 = torch.randn(
                [num_layers * num_directions, batch_size, hidden_size])
            self.compare_forward_outputs(
                rnn,
                dp_rnn,
                x,
                (h0, c0) if use_cn else h0,
                output_names=("out", "hn", "cn") if use_cn else ("out", "hn"),
                atol=1e-5,
                rtol=1e-3,
            )
            self.compare_gradients(
                rnn,
                dp_rnn,
                rnn_train_fn,
                x,
                (h0, c0) if use_cn else h0,
                atol=1e-5,
                rtol=1e-3,
            )
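The train functions handed to compare_gradients (rnn_train_fn, lstm_train_fn) are also external to this excerpt. A minimal sketch of the shape such a function could take: run one forward pass, reduce to a scalar loss, and backpropagate so that parameter gradients exist for comparison. The actual helpers may differ.

import torch
from torch.nn.utils.rnn import PackedSequence

def rnn_train_fn(model, x, state=None):
    # Hypothetical sketch; works for RNN/GRU/LSTM since all return
    # (output, hidden_state) from forward().
    model.train()
    out, _ = model(x) if state is None else model(x, state)
    if isinstance(out, PackedSequence):
        out = out.data  # reduce over the flattened packed values
    out.sum().backward()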