def __init__(self,
             input_size: int,
             cell_fw: Optional[RNNCellBase[State]] = None,
             cell_bw: Optional[RNNCellBase[State]] = None,
             output_layer_fw: Optional[nn.Module] = None,
             output_layer_bw: Optional[nn.Module] = None,
             hparams=None):
    super().__init__(hparams=hparams)

    # Make RNN cells
    if cell_fw is not None:
        self._cell_fw = cell_fw
    else:
        self._cell_fw = layers.get_rnn_cell(input_size,
                                            self._hparams.rnn_cell_fw)

    if cell_bw is not None:
        self._cell_bw = cell_bw
    elif self._hparams.rnn_cell_share_config:
        self._cell_bw = layers.get_rnn_cell(input_size,
                                            self._hparams.rnn_cell_fw)
    else:
        self._cell_bw = layers.get_rnn_cell(input_size,
                                            self._hparams.rnn_cell_bw)

    # Make output layers
    self._output_layer_fw: Optional[nn.Module]
    if output_layer_fw is not None:
        self._output_layer_fw = output_layer_fw
        self._output_layer_hparams_fw = None
    else:
        self._output_layer_fw = _build_dense_output_layer(  # type: ignore
            self._cell_fw.hidden_size, self._hparams.output_layer_fw)
        self._output_layer_hparams_fw = self._hparams.output_layer_fw

    self._output_layer_bw: Optional[nn.Module]
    if output_layer_bw is not None:
        self._output_layer_bw = output_layer_bw
        self._output_layer_hparams_bw = None
    elif self._hparams.output_layer_share_config:
        self._output_layer_bw = _build_dense_output_layer(  # type: ignore
            self._cell_bw.hidden_size, self._hparams.output_layer_fw)
        self._output_layer_hparams_bw = self._hparams.output_layer_fw
    else:
        self._output_layer_bw = _build_dense_output_layer(  # type: ignore
            self._cell_bw.hidden_size, self._hparams.output_layer_bw)
        self._output_layer_hparams_bw = self._hparams.output_layer_bw
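# --- Hedged usage sketch (not part of the constructor above) ---
# The constructor builds separate forward/backward cells, reusing the forward
# configuration when `rnn_cell_share_config` is set. The class name
# `BidirectionalRNNEncoder` and the import path are assumptions about the
# texar.torch module this constructor appears to belong to.
from texar.torch.modules import BidirectionalRNNEncoder

encoder = BidirectionalRNNEncoder(
    input_size=100,
    hparams={
        "rnn_cell_fw": {"kwargs": {"num_units": 128}},
        # Backward cell copies the forward cell's hyperparameters.
        "rnn_cell_share_config": True,
    })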
def __init__(self,
             input_size: int,
             vocab_size: int,
             token_embedder: Optional[TokenEmbedder] = None,
             token_pos_embedder: Optional[TokenPosEmbedder] = None,
             cell: Optional[RNNCellBase] = None,
             output_layer: Optional[nn.Module] = None,
             input_time_major: bool = False,
             output_time_major: bool = False,
             hparams=None):
    super().__init__(token_embedder, token_pos_embedder,
                     input_time_major, output_time_major, hparams=hparams)

    self._input_size = input_size
    self._vocab_size = vocab_size

    # Make RNN cell
    self._cell = cell or layers.get_rnn_cell(
        input_size, self._hparams.rnn_cell)
    self._beam_search_cell = None

    # Make the output layer
    self._output_layer, _ = _make_output_layer(
        output_layer, self._vocab_size, self._cell.hidden_size,
        self._hparams.output_layer_bias)
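# --- Hedged usage sketch ---
# A concrete decoder built on the base constructor above: the cell comes from
# `hparams.rnn_cell` and the output layer projects cell outputs onto the
# vocabulary. `BasicRNNDecoder` and passing a plain `nn.Embedding` as the
# token embedder are assumptions about the texar.torch API.
import torch.nn as nn
from texar.torch.modules import BasicRNNDecoder

embedder = nn.Embedding(num_embeddings=10000, embedding_dim=100)
decoder = BasicRNNDecoder(
    input_size=100, vocab_size=10000, token_embedder=embedder)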
def __init__(self, *args, **kwargs):
    super(TestConnectors, self).__init__(*args, **kwargs)
    self._batch_size = 100

    self._decoder_cell = layers.get_rnn_cell(
        256, layers.default_rnn_cell_hparams())
def __init__(self,
             input_size: int,
             cell: Optional[RNNCellBase[State]] = None,
             output_layer: Optional[nn.Module] = None,
             hparams=None):
    super().__init__(hparams=hparams)

    # Make RNN cell
    if cell is not None:
        self._cell = cell
    else:
        self._cell = layers.get_rnn_cell(input_size, self._hparams.rnn_cell)

    # Make output layer
    self._output_layer: Optional[nn.Module]
    if output_layer is not None:
        self._output_layer = output_layer
        self._output_layer_hparams = None
    else:
        self._output_layer = _build_dense_output_layer(
            self._cell.hidden_size, self._hparams.output_layer)
        self._output_layer_hparams = self._hparams.output_layer
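# --- Hedged usage sketch ---
# The encoder above either wraps a user-supplied cell or builds one from
# `hparams.rnn_cell` via `layers.get_rnn_cell`. The class name
# `UnidirectionalRNNEncoder` and its forward signature are assumptions about
# the texar.torch module this constructor appears to come from.
import torch
from texar.torch.modules import UnidirectionalRNNEncoder

encoder = UnidirectionalRNNEncoder(
    input_size=100,
    hparams={"rnn_cell": {"kwargs": {"num_units": 256}}})
outputs, final_state = encoder(
    inputs=torch.rand(64, 20, 100),   # (batch_size, max_time, input_size)
    sequence_length=torch.full((64,), 20, dtype=torch.long))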
def setUp(self) -> None:
    self._batch_size = 100

    self._decoder_cell = layers.get_rnn_cell(
        256, layers.default_rnn_cell_hparams())
def __init__(self,
             input_size: int,
             encoder_output_size: int,
             vocab_size: int,
             token_embedder: Optional[TokenEmbedder] = None,
             token_pos_embedder: Optional[TokenPosEmbedder] = None,
             cell: Optional[RNNCellBase] = None,
             output_layer: Optional[Union[nn.Module, torch.Tensor]] = None,
             cell_input_fn: Optional[Callable[[torch.Tensor, torch.Tensor],
                                              torch.Tensor]] = None,
             hparams=None):
    super().__init__(input_size, vocab_size,
                     token_embedder, token_pos_embedder,
                     cell=cell, output_layer=output_layer, hparams=hparams)

    attn_hparams = self._hparams['attention']
    attn_kwargs = attn_hparams['kwargs'].todict()

    # Compute the correct input_size internally.
    if cell is None:
        if cell_input_fn is None:
            if attn_hparams["attention_layer_size"] is None:
                input_size += encoder_output_size
            else:
                input_size += attn_hparams["attention_layer_size"]
        else:
            if attn_hparams["attention_layer_size"] is None:
                input_size = cell_input_fn(
                    torch.empty(input_size),
                    torch.empty(encoder_output_size)).shape[-1]
            else:
                input_size = cell_input_fn(
                    torch.empty(input_size),
                    torch.empty(
                        attn_hparams["attention_layer_size"])).shape[-1]
        self._cell = layers.get_rnn_cell(input_size, self._hparams.rnn_cell)

    # Parse the `probability_fn` argument
    if 'probability_fn' in attn_kwargs:
        prob_fn = attn_kwargs['probability_fn']
        if prob_fn is not None and not callable(prob_fn):
            prob_fn = get_function(
                prob_fn, ['torch.nn.functional', 'texar.torch.core'])
        attn_kwargs['probability_fn'] = prob_fn

    # Parse `encoder_output_size` and `decoder_output_size` arguments
    if attn_hparams['type'] in ['BahdanauAttention',
                                'BahdanauMonotonicAttention']:
        attn_kwargs.update({"decoder_output_size": self._cell.hidden_size})
    attn_kwargs.update({"encoder_output_size": encoder_output_size})

    attn_modules = ['texar.torch.core']
    # TODO: Support multiple attention mechanisms.
    self.attention_mechanism: AttentionMechanism
    self.attention_mechanism = check_or_get_instance(
        attn_hparams["type"], attn_kwargs, attn_modules,
        classtype=AttentionMechanism)

    self._attn_cell_kwargs = {
        "attention_layer_size": attn_hparams["attention_layer_size"],
        "alignment_history": attn_hparams["alignment_history"],
        "output_attention": attn_hparams["output_attention"],
    }
    self._cell_input_fn = cell_input_fn

    if attn_hparams["output_attention"] and vocab_size is not None and \
            self.attention_mechanism is not None:
        if attn_hparams["attention_layer_size"] is None:
            self._output_layer = nn.Linear(encoder_output_size, vocab_size)
        else:
            self._output_layer = nn.Linear(
                sum(attn_hparams["attention_layer_size"])
                if isinstance(attn_hparams["attention_layer_size"], list)
                else attn_hparams["attention_layer_size"],
                vocab_size)

    attn_cell = AttentionWrapper(self._cell,
                                 self.attention_mechanism,
                                 cell_input_fn=self._cell_input_fn,
                                 **self._attn_cell_kwargs)
    self._cell: AttentionWrapper = attn_cell

    self.memory: Optional[torch.Tensor] = None
    self.memory_sequence_length: Optional[torch.LongTensor] = None
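# --- Hedged sketch of the `attention` hyperparameter block consumed above ---
# The keys (`type`, `kwargs`, `attention_layer_size`, `alignment_history`,
# `output_attention`) are exactly the ones read by the constructor; the
# specific values and the `LuongAttention` type are illustrative assumptions.
attention_hparams = {
    "attention": {
        "type": "LuongAttention",      # resolved from `texar.torch.core`
        "kwargs": {"num_units": 256},  # forwarded to the attention class
        "attention_layer_size": 256,   # None => context has encoder_output_size dims
        "alignment_history": False,
        "output_attention": True,      # project attention output to the vocabulary
    },
}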
def test_get_rnn_cell(self):
    r"""Tests :func:`get_rnn_cell`.
    """
    input_size = 10
    hparams = {
        'type': 'LSTMCell',
        'kwargs': {
            'num_units': 20,
            'forget_bias': 1.0,
        },
        'num_layers': 3,
        'dropout': {
            'input_keep_prob': 0.5,
            'output_keep_prob': 0.5,
            'state_keep_prob': 0.5,
            'variational_recurrent': True,
        },
        'residual': True,
        'highway': True,
    }

    hparams = HParams(hparams, default_rnn_cell_hparams())
    rnn_cell = get_rnn_cell(input_size, hparams)

    self.assertIsInstance(rnn_cell, wrappers.MultiRNNCell)
    self.assertEqual(len(rnn_cell._cell), hparams.num_layers)
    self.assertEqual(rnn_cell.input_size, input_size)
    self.assertEqual(rnn_cell.hidden_size, hparams.kwargs.num_units)

    for idx, cell in enumerate(rnn_cell._cell):
        layer_input_size = (input_size if idx == 0
                            else hparams.kwargs.num_units)
        self.assertEqual(cell.input_size, layer_input_size)
        self.assertEqual(cell.hidden_size, hparams.kwargs.num_units)

        if idx > 0:
            highway = cell
            residual = highway._cell
            dropout = residual._cell
            self.assertIsInstance(highway, wrappers.HighwayWrapper)
            self.assertIsInstance(residual, wrappers.ResidualWrapper)
        else:
            dropout = cell
        lstm = dropout._cell
        builtin_lstm = lstm._cell
        self.assertIsInstance(dropout, wrappers.DropoutWrapper)
        self.assertIsInstance(lstm, wrappers.LSTMCell)
        self.assertIsInstance(builtin_lstm, nn.LSTMCell)

        h = hparams.kwargs.num_units
        forget_bias = builtin_lstm.bias_ih[h:(2 * h)]
        self.assertTrue((forget_bias == hparams.kwargs.forget_bias).all())

        for key in ['input', 'output', 'state']:
            self.assertEqual(getattr(dropout, f'_{key}_keep_prob'),
                             hparams.dropout[f'{key}_keep_prob'])
        self.assertTrue(dropout._variational_recurrent)

    batch_size = 8
    seq_len = 6
    state = None
    for step in range(seq_len):
        inputs = torch.zeros(batch_size, input_size)
        output, state = rnn_cell(inputs, state)
    self.assertEqual(output.shape, (batch_size, hparams.kwargs.num_units))
    self.assertEqual(len(state), hparams.num_layers)
    utils.map_structure(
        lambda s: self.assertEqual(
            s.shape, (batch_size, hparams.kwargs.num_units)),
        state)
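# --- Hedged usage sketch ---
# Outside of the test, the same cell can be built and stepped directly; this
# mirrors the loop above and relies only on the `get_rnn_cell` /
# `default_rnn_cell_hparams` functions already used in this section (the
# `texar.torch.core.layers` import path is an assumption).
import torch
from texar.torch.core import layers

cell = layers.get_rnn_cell(16, layers.default_rnn_cell_hparams())
state = None
for _ in range(4):
    inputs = torch.zeros(8, 16)           # (batch_size, input_size)
    outputs, state = cell(inputs, state)  # state is threaded across steps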