def test_RNNDecoder_invalid_type():
    """Constructing RNNDecoder with ``rnn_type="foo"`` must raise ValueError."""
    unsupported_rnn_type = "foo"
    with pytest.raises(ValueError):
        RNNDecoder(10, 12, rnn_type=unsupported_rnn_type)
def __init__(self, atype):
    """Initialise the base class, then attach an ``RNNDecoder`` as ``self.decoder``.

    Args:
        atype: Forwarded unchanged as the ``"atype"`` entry of the decoder's
            ``att_conf`` mapping (presumably selects the attention variant;
            confirm against RNNDecoder).
    """
    super().__init__()
    # NOTE(review): 50 and 128 are passed positionally; presumably input/hidden
    # sizes -- confirm against RNNDecoder's signature.
    attention_config = {"atype": atype}
    self.decoder = RNNDecoder(50, 128, att_conf=attention_config)