Esempio n. 1
0
    def setUp(self):
        """Build shared fixtures for the attention RNN decoder tests.

        Creates random integer token inputs, a ``WordEmbedder`` initialised
        from a random embedding table, a random encoder output tensor, and a
        table of decoder hyperparameters keyed by ``(cell_type, is_multi)``.
        """
        self._vocab_size = 10
        self._max_time = 16
        self._batch_size = 8
        self._emb_dim = 20
        self._attention_dim = 256

        # Random token ids in [0, vocab_size), shape (batch, max_time).
        input_shape = (self._batch_size, self._max_time)
        self._inputs = torch.randint(self._vocab_size, size=input_shape)

        init_table = torch.rand(
            self._vocab_size, self._emb_dim, dtype=torch.float)
        self._embedder = WordEmbedder(init_value=init_table)

        self._encoder_output = torch.rand(
            self._batch_size, self._max_time, 64)

        def build_hparams(cell_type, num_layers=None):
            # Every configuration uses 256-unit cells and the same attention
            # size; only the cell type and (optionally) layer count vary.
            cell_cfg = {'type': cell_type, 'kwargs': {'num_units': 256}}
            if num_layers is not None:
                cell_cfg['num_layers'] = num_layers
            raw = {
                "rnn_cell": cell_cfg,
                "attention": {"kwargs": {"num_units": self._attention_dim}},
            }
            return HParams(raw, AttentionRNNDecoder.default_hparams())

        # (cell_type, is_multi) -> hparams
        self._test_hparams = {
            (cell, False): build_hparams(cell)
            for cell in ("RNNCell", "LSTMCell", "GRUCell")
        }
        self._test_hparams[("LSTMCell", True)] = build_hparams(
            "LSTMCell", num_layers=3)
    def setUp(self):
        """Build shared fixtures for the attention RNN decoder tests.

        Creates random float inputs, a random embedding table, a random
        encoder output tensor, and one ``HParams`` per decoder configuration:
        single-layer RNN, LSTM and GRU cells, plus a three-layer RNN stack.
        """
        self._vocab_size = 10
        self._max_time = 16
        self._batch_size = 8
        self._emb_dim = 20
        self._attention_dim = 256

        # Random float inputs of shape (batch, max_time, emb_dim).
        self._inputs = torch.rand(
            self._batch_size, self._max_time, self._emb_dim,
            dtype=torch.float32)
        self._embedding = torch.rand(
            self._vocab_size, self._emb_dim, dtype=torch.float32)
        self._encoder_output = torch.rand(
            self._batch_size, self._max_time, 64)

        def make_hparams(cell_type, num_layers=None):
            # All variants share 256-unit cells and the same attention size;
            # the multi-layer variant additionally sets 'num_layers'.
            cell = {'type': cell_type, 'kwargs': {'num_units': 256}}
            if num_layers is not None:
                cell['num_layers'] = num_layers
            return HParams(
                {
                    "rnn_cell": cell,
                    "attention": {
                        "kwargs": {"num_units": self._attention_dim},
                    },
                },
                AttentionRNNDecoder.default_hparams())

        self._hparams_rnn = make_hparams('RNNCell')
        self._hparams_lstm = make_hparams('LSTMCell')
        self._hparams_gru = make_hparams('GRUCell')
        self._hparams_multicell = make_hparams('RNNCell', num_layers=3)