Example #1
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 target_namespace: str = "tokens",
                 attention_function: SimilarityFunction = None,
                 scheduled_sampling_ratio: float = 0.0,
                 label_smoothing: float = None,
                 target_embedding_dim: int = None,
                 target_tokens_embedder: TokenEmbedder = None) -> None:
        super(PretrSeq2Seq, self).__init__(vocab)
        self._label_smoothing = label_smoothing
        self._source_embedder = source_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace
        self._attention_function = attention_function
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder. Also, if
        # we're using attention with ``DotProductSimilarity``, this is needed.
        self._decoder_output_dim = self._encoder.get_output_dim()

        target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # PRETRAINED PART: if a pretrained target-side embedder is given, use it (and its output
        # dimension) in place of the randomly initialized target embedder created above.
        if target_tokens_embedder:
            target_embedding_dim = target_tokens_embedder.get_output_dim()
            self._target_embedder = target_tokens_embedder

        if self._attention_function:
            self._decoder_attention = LegacyAttention(self._attention_function)
            # The output of attention, a weighted average over encoder outputs, will be
            # concatenated to the input vector of the decoder at each time step.
            self._decoder_input_dim = self._encoder.get_output_dim() + target_embedding_dim
        else:
            self._decoder_input_dim = target_embedding_dim
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                      self._decoder_output_dim)
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)
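
A minimal construction sketch for this class, assuming AllenNLP 0.x-style components. The vocabulary contents, all sizes, and the 300-dimensional stand-in for a pretrained target embedder are illustrative placeholders.

import torch
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data import Vocabulary
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding

vocab = Vocabulary()  # normally built from a dataset reader
for token in (START_SYMBOL, END_SYMBOL, "hello", "world"):
    vocab.add_token_to_namespace(token, namespace="tokens")

source_embedder = BasicTextFieldEmbedder(
    {"tokens": Embedding(vocab.get_vocab_size("tokens"), 128)})
encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(128, 256, batch_first=True))
# Stand-in for an embedder whose weights were pretrained elsewhere.
pretrained_target_embedder = Embedding(vocab.get_vocab_size("tokens"), 300)

model = PretrSeq2Seq(
    vocab=vocab,
    source_embedder=source_embedder,
    encoder=encoder,
    max_decoding_steps=50,
    target_tokens_embedder=pretrained_target_embedder,  # replaces the random target embedder
)
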
Example #2
    def __init__(self,
                 input_dim: int,
                 decoder_rnn_output_dim: int,
                 output_projection_dim: int,
                 max_decoding_steps: int,
                 attention: Attention = None,
                 attention_function: SimilarityFunction = None,
                 scheduled_sampling_ratio: float = 0.) -> None:
        super(AttentionalRnnDecoder, self).__init__()
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._input_dim = input_dim
        self._output_dim = output_projection_dim

        # Attention mechanism applied to the encoder output for each step.
        if attention:
            if attention_function:
                raise ConfigurationError(
                    "You can only specify an attention module or an "
                    "attention function, but not both.")
            self._attention = attention
        elif attention_function:
            self._attention = LegacyAttention(attention_function)
        else:
            self._attention = None

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            decoder_rnn_input_dim = input_dim + output_projection_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            decoder_rnn_input_dim = output_projection_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(decoder_rnn_input_dim,
                                      decoder_rnn_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(decoder_rnn_output_dim,
                                               output_projection_dim)

        # At prediction time, decoding runs for at most ``max_decoding_steps`` time steps.
        self._max_decoding_steps = max_decoding_steps
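
As a hedged usage sketch, this decoder could be constructed with one of AllenNLP's concrete Attention modules; the sizes below are illustrative only.

from allennlp.modules.attention import DotProductAttention

decoder = AttentionalRnnDecoder(
    input_dim=256,               # size of the encoder outputs attended over
    decoder_rnn_output_dim=256,  # dot-product attention assumes query and key sizes match
    output_projection_dim=512,   # also the size of the previous-step input fed back in
    max_decoding_steps=40,
    attention=DotProductAttention(),
)

Passing both attention and attention_function raises a ConfigurationError, as enforced above.
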
Example #3
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention = None,
                 attention_function: SimilarityFunction = None,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True,
                 emb_dropout: float = 0.5) -> None:
        super(Seq2Seq, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        self._token_based_metric = TokenSequenceAccuracy()

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder
        self._emb_dropout = Dropout(p=emb_dropout)

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        if attention:
            if attention_function:
                raise ConfigurationError("You can only specify an attention module or an "
                                         "attention function, but not both.")
            self._attention = attention
        elif attention_function:
            self._attention = LegacyAttention(attention_function)
        else:
            self._attention = None

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)
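
A sketch of wiring this model up with attention and beam search, again assuming AllenNLP 0.x-style components and illustrative sizes; TokenSequenceAccuracy is this project's own metric and must be importable for the constructor to run.

import torch
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data import Vocabulary
from allennlp.modules.attention import DotProductAttention
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding

vocab = Vocabulary()
for token in (START_SYMBOL, END_SYMBOL):
    vocab.add_token_to_namespace(token, namespace="tokens")

model = Seq2Seq(
    vocab=vocab,
    source_embedder=BasicTextFieldEmbedder(
        {"tokens": Embedding(vocab.get_vocab_size("tokens"), 128)}),
    encoder=PytorchSeq2SeqWrapper(torch.nn.LSTM(128, 256, batch_first=True)),
    max_decoding_steps=60,
    attention=DotProductAttention(),  # decoder state (256) and encoder outputs (256) match
    beam_size=5,
    use_bleu=True,
    emb_dropout=0.3,
)
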
Example #4
    def __init__(
        self,
        vocab: Vocabulary,
        encoder: Seq2SeqEncoder,
        input_size: int,
        target_embedding_dim: int,
        decoder_hidden_dim: int,
        max_decoding_steps: int,
        max_decoding_ratio: float = 1.5,
        dep_parser: Model = None,
        pos_tagger: Model = None,
        cmvn: str = 'none',
        delta: int = 0,
        time_mask_width: int = 0,
        freq_mask_width: int = 0,
        time_mask_max_ratio: float = 0.0,
        dec_layers: int = 1,
        layerwise_pretraining: List[Tuple[int, int]] = None,
        cnn: Seq2SeqEncoder = None,
        conv_lstm: Seq2SeqEncoder = None,
        train_at_phn_level: bool = False,
        rnnt_layer: Model = None,
        phn_ctc_layer: Model = None,
        ctc_layer: Model = None,
        projection_layer: nn.Module = None,
        tie_proj: bool = False,
        att_ratio: float = 0.0,
        dep_ratio: float = 0.0,
        pos_ratio: float = 0.0,
        attention: Attention = None,
        attention_function: SimilarityFunction = None,
        latency_penalty: float = 0.0,
        loss_type: str = "nll",
        beam_size: int = 1,
        target_namespace: str = "tokens",
        phoneme_target_namespace: str = "phonemes",
        dropout: float = 0.0,
        blank: str = "_",
        sampling_strategy: str = "max",
        from_candidates: bool = False,
        scheduled_sampling_ratio: float = 0.,
        initializer: InitializerApplicator = InitializerApplicator()
    ) -> None:
        super(PhnMoChA, self).__init__(vocab)
        self._input_size = input_size
        self._target_namespace = target_namespace
        self._phn_target_namespace = phoneme_target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._sampling_strategy = sampling_strategy
        self._train_at_phn_level = train_at_phn_level
        self._blank = blank

        self._dep_parser = dep_parser
        self._pos_tagger = pos_tagger
        self._ctc_layer = ctc_layer
        self._rnnt_layer = rnnt_layer
        self._phn_ctc_layer = phn_ctc_layer
        self._projection_layer = projection_layer
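        # Optionally tie the RNN-T head's output projection to the shared projection layer.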
        if tie_proj:
            self._rnnt_layer.set_projection_layer(projection_layer)
        self._att_ratio = att_ratio
        self._dep_ratio = dep_ratio
        self._pos_ratio = pos_ratio
        self._loss_type = loss_type
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)  # pylint: disable=protected-access
        self._phn_pad_index = self.vocab.get_token_index(
            self.vocab._padding_token, self._phn_target_namespace)  # pylint: disable=protected-access

        exclude_indices = {self._pad_index, self._end_index, self._start_index}

        self._logs: Dict[str, Union[Metric, None]] = {
            "att_wer": (WER(exclude_indices=exclude_indices)
                        if self._att_ratio > 0 else None),
            "att_bleu": (BLEU(exclude_indices=exclude_indices)
                         if self._att_ratio > 0 else None),
            "att_loss": (Average() if self._att_ratio > 0 else None),
            "phn_ctc_loss": (Average() if self._phn_ctc_layer else None),
            "ctc_loss": (Average() if self._ctc_layer else None),
            "rnnt_loss": (Average() if self._rnnt_layer else None),
            "dal_loss": (Average() if latency_penalty > 0.0 else None),
            "dep_loss": (Average() if self._dep_parser else None),
            "pos_loss": (Average() if self._pos_tagger else None),
            "tag_loss": (Average() if self._dep_parser else None),
            "arc_loss": (Average() if self._dep_parser else None)
        }

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        self._max_decoding_steps = max_decoding_steps
        self._max_decoding_ratio = max_decoding_ratio
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        self._cnn = cnn
        self._conv_lstm = conv_lstm

        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        self._num_classes = num_classes

        # Attention mechanism applied to the encoder output for each step.
        if attention:
            if attention_function:
                raise ConfigurationError(
                    "You can only specify an attention module or an "
                    "attention function, but not both.")
            self._attention = attention
        elif attention_function:
            self._attention = LegacyAttention(attention_function)
        else:
            self._attention = None

        # Dense embedding of vocab words in the target space.
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = decoder_hidden_dim
        self._dec_layers = dec_layers
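        # If the decoder hidden size differs from the encoder output size, a linear bridge maps
        # the encoder's final state to one decoder-sized vector per decoder layer.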
        if self._decoder_output_dim != self._encoder_output_dim:
            self.bridge = nn.Linear(self._encoder_output_dim,
                                    self._dec_layers *
                                    self._decoder_output_dim,
                                    bias=False)

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
            self.att_out = Linear(self._decoder_output_dim +
                                  self._encoder_output_dim,
                                  self._decoder_output_dim,
                                  bias=True)
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder = nn.LSTM(self._decoder_input_dim,
                                self._decoder_output_dim,
                                num_layers=self._dec_layers,
                                batch_first=True)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)

        # Input feature normalization (CMVN): 'global' applies BatchNorm1d, 'utt' applies
        # per-utterance InstanceNorm1d, and any other value leaves the features unchanged.
        self._input_norm = lambda x: x
        if cmvn == 'global':
            self._input_norm = nn.BatchNorm1d(self._input_size * (delta + 1))
        elif cmvn == 'utt':
            self._input_norm = nn.InstanceNorm1d(self._input_size *
                                                 (delta + 1))

        self._delta = None
        if delta > 0:
            self._delta = Delta(order=delta)

        self._epoch_num = float("inf")
        self._layerwise_pretraining = layerwise_pretraining
        try:
            if isinstance(self._encoder, PytorchSeq2SeqWrapper):
                self._num_layers = self._encoder._module.num_layers
            else:
                self._num_layers = self._encoder.num_layers
        except AttributeError:
            self._num_layers = float("inf")

        self._output_layer_num = self._num_layers

        self._loss = None

        self._from_candidates = from_candidates
        if loss_type == "ocd":
            self._loss = OCDLoss(self._end_index, 1e-7, 1e-7, 5)
        elif loss_type == "edocd":
            self._loss = EDOCDLoss(self._end_index, 1e-7, 1e-7, 5)

        self._latency_penalty = latency_penalty
        self._target_granularity = self._target_namespace

        self.time_mask = TimeMask(time_mask_width, time_mask_max_ratio)
        self.freq_mask = FreqMask(freq_mask_width)

        initializer(self)
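
Most of the heads above are optional; a minimal, hedged construction using only the required arguments might look like the following. The 80-dimensional filterbank input, the LSTM encoder, and all sizes are illustrative, and the project's own modules (e.g. TimeMask, FreqMask) are assumed to be importable alongside PhnMoChA.

import torch
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data import Vocabulary
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

vocab = Vocabulary()
for token in (START_SYMBOL, END_SYMBOL):
    vocab.add_token_to_namespace(token, namespace="tokens")

model = PhnMoChA(
    vocab=vocab,
    encoder=PytorchSeq2SeqWrapper(torch.nn.LSTM(80, 320, batch_first=True)),
    input_size=80,             # acoustic feature dimension (placeholder)
    target_embedding_dim=256,
    decoder_hidden_dim=320,    # matches the encoder output, so no bridge layer is created
    max_decoding_steps=100,
)
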
Example #5
    def __init__(
        self,
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        max_decoding_steps: int,
        target_namespace: str = "tokens",
        target_embedding_dim: int = None,
        attention_function: SimilarityFunction = None,
        scheduled_sampling_ratio: float = 0.0,
        weight_function="softmax",
        gumbel_tau: float = 0.66,
        gumbel_hard: bool = True,
        gumbel_eps: float = 1e-10,
        infer_with: str = "distribution",
        self_feed_with: str = "argmax_distribution",
    ) -> None:
        super(Rnn2RnnDifferentiableNll, self).__init__(vocab)
        self._source_embedder = source_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace
        self._attention_function = attention_function
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder. Also, if
        # we're using attention with ``DotProductSimilarity``, this is needed.
        self._decoder_output_dim = self._encoder.get_output_dim()
        target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()
        self._target_embedder = Embedding(num_classes, target_embedding_dim)
        if self._attention_function:
            self._decoder_attention = LegacyAttention(self._attention_function)
            # The output of attention, a weighted average over encoder outputs, will be
            # concatenated to the input vector of the decoder at each time step.
            self._decoder_input_dim = self._encoder.get_output_dim() + target_embedding_dim
        else:
            self._decoder_input_dim = target_embedding_dim
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                      self._decoder_output_dim)
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)

        self._weights_calculation_function = weight_function

        self._gumbel_tau = gumbel_tau
        self._gumbel_hard = gumbel_hard
        self._gumbel_eps = gumbel_eps

        if self_feed_with not in {
                "distribution", "argmax_logits", "argmax_distribution",
                "detach_distribution"
        }:
            raise ValueError(
                "Allowed values for self_feed_with are {distribution, argmax_logits, "
                "argmax_distribution, detach_distribution}")

        if infer_with not in {
                "distribution", "argmax_logits", "argmax_distribution"
        }:
            raise ValueError(
                "Allowed values for infer_with are {distribution, argmax_logits, argmax_distribution}")

        self._infer_with = infer_with
        self._self_feed_with = self_feed_with
Example #6
File: model.py  Project: vivi0204/-
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 intent_encoder: Seq2SeqEncoder = None,
                 tag_encoder: Seq2SeqEncoder = None,
                 attention: Attention = None,
                 attention_function: SimilarityFunction = None,
                 context_for_intent: bool = True,
                 context_for_tag: bool = True,
                 attention_for_intent: bool = True,
                 attention_for_tag: bool = True,
                 sequence_label_namespace: str = "labels",
                 intent_label_namespace: str = "intent_labels",
                 feedforward: Optional[FeedForward] = None,
                 label_encoding: Optional[str] = None,
                 include_start_end_transitions: bool = True,
                 crf_decoding: bool = False,
                 constrain_crf_decoding: bool = None,
                 focal_loss_gamma: float = None,
                 nongeneral_intent_weight: float = 5.,
                 num_train_examples: float = None,
                 calculate_span_f1: bool = None,
                 dropout: Optional[float] = None,
                 verbose_metrics: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self.context_for_intent = context_for_intent
        self.context_for_tag = context_for_tag
        self.attention_for_intent = attention_for_intent
        self.attention_for_tag = attention_for_tag
        self.sequence_label_namespace = sequence_label_namespace
        self.intent_label_namespace = intent_label_namespace
        self.text_field_embedder = text_field_embedder
        self.num_tags = self.vocab.get_vocab_size(sequence_label_namespace)
        self.num_intents = self.vocab.get_vocab_size(intent_label_namespace)
        self.encoder = encoder
        self.intent_encoder = intent_encoder
        self.tag_encoder = tag_encoder
        self._feedforward = feedforward
        self._verbose_metrics = verbose_metrics
        self.rl = False

        if attention:
            if attention_function:
                raise ConfigurationError("You can only specify an attention module or an "
                                         "attention function, but not both.")
            self.attention = attention
        elif attention_function:
            self.attention = LegacyAttention(attention_function)
        else:
            self.attention = None

        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.dropout = None

        projection_input_dim = feedforward.get_output_dim() if self._feedforward else self.encoder.get_output_dim()
        if self.context_for_intent:
            projection_input_dim += self.encoder.get_output_dim()
        if self.attention_for_intent:
            projection_input_dim += self.encoder.get_output_dim()
        self.intent_projection_layer = Linear(projection_input_dim, self.num_intents)

        intent_names = list(self.vocab.get_index_to_token_vocabulary(intent_label_namespace).values())
        if num_train_examples:
            # pos_weight[i] = log10((N - count_i) / count_i): the log-odds of a negative vs. a
            # positive example for each intent, so rarer intents get a larger positive-class weight.
            try:
                counts = self.vocab._retained_counter[intent_label_namespace]  # pylint: disable=protected-access
                pos_weight = torch.tensor([log10((num_train_examples - counts[name]) / counts[name])
                                           for name in intent_names])
            except Exception:  # counts unavailable (e.g. the vocabulary was loaded from files)
                pos_weight = torch.tensor([1. for _ in intent_names])
        else:
            # Up-weight "Request" intents by a fixed factor.
            # (An earlier variant instead up-weighted every intent whose name does not contain "general".)
            pos_weight = torch.tensor([nongeneral_intent_weight if "Request" in name else 1.
                                       for name in intent_names])
        self.intent_loss = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="none")

        tag_projection_input_dim = feedforward.get_output_dim() if self._feedforward else self.encoder.get_output_dim()
        if self.context_for_tag:
            tag_projection_input_dim += self.encoder.get_output_dim()
        if self.attention_for_tag:
            tag_projection_input_dim += self.encoder.get_output_dim()
        self.tag_projection_layer = TimeDistributed(Linear(tag_projection_input_dim,
                                                           self.num_tags))

        # If constrain_crf_decoding and calculate_span_f1 are not provided (i.e., they are None),
        # set them to True if label_encoding is provided and to False otherwise.
        if constrain_crf_decoding is None:
            constrain_crf_decoding = label_encoding is not None
        if calculate_span_f1 is None:
            calculate_span_f1 = label_encoding is not None

        self.label_encoding = label_encoding
        if constrain_crf_decoding:
            if not label_encoding:
                raise ConfigurationError("constrain_crf_decoding is True, but "
                                         "no label_encoding was specified.")
            labels = self.vocab.get_index_to_token_vocabulary(sequence_label_namespace)
            constraints = allowed_transitions(label_encoding, labels)
        else:
            constraints = None

        self.include_start_end_transitions = include_start_end_transitions
        if crf_decoding:
            self.crf = ConditionalRandomField(
                    self.num_tags, constraints,
                    include_start_end_transitions=include_start_end_transitions
            )
        else:
            self.crf = None

        self._intent_f1_metric = MultiLabelF1Measure(vocab,
                                                     namespace=intent_label_namespace)
        self.calculate_span_f1 = calculate_span_f1
        if calculate_span_f1:
            if not label_encoding:
                raise ConfigurationError("calculate_span_f1 is True, but "
                                          "no label_encoding was specified.")
            self._f1_metric = SpanBasedF1Measure(vocab,
                                                 tag_namespace=sequence_label_namespace,
                                                 label_encoding=label_encoding)
        self._dai_f1_metric = DialogActItemF1Measure()

        check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        if feedforward is not None:
            check_dimensions_match(encoder.get_output_dim(), feedforward.get_input_dim(),
                                   "encoder output dim", "feedforward input dim")
        initializer(self)
Example #7
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedding: Embedding,
                 target_embedding: Embedding,
                 encoder: Seq2SeqEncoder,
                 target_namespace: str,
                 max_decoding_steps: int,
                 attention: Attention = None,
                 attention_function: SimilarityFunction = None,
                 beam_size: int = None,
                 scheduled_sampling_ratio: float = 0.) -> None:
        super().__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        # Dense embedding of source vocab tokens.
        self._source_embedding = source_embedding

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        if attention:
            if attention_function:
                raise ConfigurationError(
                    "You can only specify an attention module or an "
                    "attention function, but not both.")
            self._attention = attention
        elif attention_function:
            self._attention = LegacyAttention(attention_function)
        else:
            self._attention = None

        # Dense embedding of vocab words in the target space.

        self._target_embedding = target_embedding
        target_embedding_dim = self._target_embedding.get_output_dim()

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder output will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                      self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)

        # At prediction time, we can use a beam search to find the most likely sequence of target tokens.
        # If the beam_size parameter is not given, we'll just use a greedy search (equivalent to beam_size = 1).
        self._max_decoding_steps = max_decoding_steps
        if beam_size is not None:
            self._beam_search = BeamSearch(self._end_index,
                                           max_steps=max_decoding_steps,
                                           beam_size=beam_size)
        else:
            self._beam_search = None
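
The class name is not shown in this snippet, so Seq2SeqWithSeparateEmbeddings below is purely a placeholder. The sketch illustrates the two distinguishing features of this variant: separate source and target Embedding modules, and the beam_size switch between beam search and greedy decoding. Sizes are illustrative.

import torch
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data import Vocabulary
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.token_embedders import Embedding

vocab = Vocabulary()
for token in (START_SYMBOL, END_SYMBOL):
    vocab.add_token_to_namespace(token, namespace="target_tokens")

model = Seq2SeqWithSeparateEmbeddings(  # placeholder class name
    vocab=vocab,
    source_embedding=Embedding(vocab.get_vocab_size("tokens"), 128),
    target_embedding=Embedding(vocab.get_vocab_size("target_tokens"), 192),
    encoder=PytorchSeq2SeqWrapper(torch.nn.LSTM(128, 256, batch_first=True)),
    target_namespace="target_tokens",
    max_decoding_steps=40,
    beam_size=4,  # pass None (the default) to skip beam search and decode greedily
)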