def setUp(self):
    self._vocab_size = 10
    self._max_time = 16
    self._batch_size = 8
    self._emb_dim = 20
    self._attention_dim = 256
    self._inputs = torch.randint(
        self._vocab_size, size=(self._batch_size, self._max_time))
    embedding = torch.rand(
        self._vocab_size, self._emb_dim, dtype=torch.float)
    self._embedder = WordEmbedder(init_value=embedding)
    self._encoder_output = torch.rand(
        self._batch_size, self._max_time, 64)

    self._test_hparams = {}  # (cell_type, is_multi) -> hparams

    for cell_type in ["RNNCell", "LSTMCell", "GRUCell"]:
        hparams = {
            "rnn_cell": {
                'type': cell_type,
                'kwargs': {
                    'num_units': 256,
                },
            },
            "attention": {
                "kwargs": {
                    "num_units": self._attention_dim
                },
            }
        }
        self._test_hparams[(cell_type, False)] = HParams(
            hparams, AttentionRNNDecoder.default_hparams())

    hparams = {
        "rnn_cell": {
            'type': 'LSTMCell',
            'kwargs': {
                'num_units': 256,
            },
            'num_layers': 3,
        },
        "attention": {
            "kwargs": {
                "num_units": self._attention_dim
            },
        }
    }
    self._test_hparams[("LSTMCell", True)] = HParams(
        hparams, AttentionRNNDecoder.default_hparams())
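# A minimal usage sketch (not part of the original file) showing how a test
# body might consume `self._test_hparams`: build one decoder per entry and
# run a teacher-forcing pass over the random inputs. The method name
# `_sketch_decode_train` is hypothetical, and the sketch assumes the
# texar-pytorch AttentionRNNDecoder constructor keywords (`input_size`,
# `encoder_output_size`, `vocab_size`, `token_embedder`, `hparams`) and a
# forward call returning (outputs, final_state, sequence_lengths); adjust
# names if the installed version differs.
def _sketch_decode_train(self):
    for (cell_type, is_multi), hparams in self._test_hparams.items():
        decoder = AttentionRNNDecoder(
            input_size=self._emb_dim,
            encoder_output_size=self._encoder_output.size(-1),
            vocab_size=self._vocab_size,
            token_embedder=self._embedder,
            hparams=hparams)
        sequence_length = torch.tensor(
            [self._max_time] * self._batch_size)
        outputs, final_state, lengths = decoder(
            memory=self._encoder_output,
            memory_sequence_length=sequence_length,
            inputs=self._inputs,
            sequence_length=sequence_length)
        # Logits should cover every step of the unrolled sequence
        # over the full vocabulary, regardless of cell type or depth.
        self.assertEqual(
            outputs.logits.shape,
            torch.Size([self._batch_size, self._max_time,
                        self._vocab_size]))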