Example #1
    def forward(self,
                inputs,
                listener_outputs,
                function=F.log_softmax,
                teacher_forcing_ratio=0.90,
                use_beam_search=False):
        batch_size = inputs.size(0)
        max_length = inputs.size(1) - 1  # minus the start of sequence symbol

        decode_results = list()
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        hidden = self._init_state(batch_size)

        if use_beam_search:  # TopK Decoding
            input_ = inputs[:, 0].unsqueeze(1)
            beam = Beam(k=self.k,
                        decoder=self,
                        batch_size=batch_size,
                        max_length=max_length,
                        function=function,
                        device=self.device)
            logits = None
            y_hats = beam.search(input_, listener_outputs)

        else:
            if use_teacher_forcing:  # with teacher forcing, infer the whole sequence in one pass
                # strip <eos> tokens so only <sos> + target tokens are fed to the decoder
                speller_inputs = inputs[inputs != self.eos_id].view(
                    batch_size, -1)
                predicted_softmax, hidden = self.forward_step(
                    input_=speller_inputs,
                    hidden=hidden,
                    listener_outputs=listener_outputs,
                    function=function)

                for di in range(predicted_softmax.size(1)):
                    step_output = predicted_softmax[:, di, :]
                    decode_results.append(step_output)

            else:
                speller_input = inputs[:, 0].unsqueeze(1)

                for di in range(max_length):
                    predicted_softmax, hidden = self.forward_step(
                        input_=speller_input,
                        hidden=hidden,
                        listener_outputs=listener_outputs,
                        function=function)
                    step_output = predicted_softmax.squeeze(1)
                    decode_results.append(step_output)
                    speller_input = decode_results[-1].topk(1)[1]  # feed the argmax prediction back as the next input

            logits = torch.stack(decode_results, dim=1).to(self.device)
            y_hats = logits.max(-1)[1]

        return y_hats, logits
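
The branch structure above switches between teacher forcing (one pass over the whole ground-truth sequence) and greedy step-by-step decoding that feeds each argmax prediction back in. Below is a minimal, self-contained sketch of that control flow; the dummy_step helper, the linear projection, and all sizes are illustrative assumptions and are not part of the Speller class shown above.

import random
import torch
import torch.nn as nn
import torch.nn.functional as F

# illustrative sizes, not taken from the example above
batch_size, max_length, num_classes, hidden_size = 2, 5, 10, 16
inputs = torch.randint(1, num_classes, (batch_size, max_length + 1))  # <sos> + targets
projection = nn.Linear(hidden_size, num_classes)

def dummy_step(tokens):
    # hypothetical stand-in for forward_step: random features projected to log-probabilities
    feats = torch.randn(tokens.size(0), tokens.size(1), hidden_size)
    return F.log_softmax(projection(feats), dim=-1)  # (batch, length, num_classes)

decode_results = []
if random.random() < 0.90:                       # teacher forcing: feed the ground truth
    predicted = dummy_step(inputs[:, :-1])       # one pass over all decoder inputs
    decode_results = [predicted[:, di, :] for di in range(predicted.size(1))]
else:                                            # greedy decoding: feed back own predictions
    step_input = inputs[:, 0].unsqueeze(1)       # start from <sos>
    for _ in range(max_length):
        step_output = dummy_step(step_input).squeeze(1)
        decode_results.append(step_output)
        step_input = step_output.topk(1)[1]      # argmax token becomes the next input

logits = torch.stack(decode_results, dim=1)      # (batch, max_length, num_classes)
y_hats = logits.max(-1)[1]                       # (batch, max_length) predicted token ids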
Example #2
    def forward(self,
                inputs,
                encoder_outputs,
                function=F.log_softmax,
                teacher_forcing_ratio=0.90,
                use_beam_search=False):
        y_hats, logits = None, None
        decode_results = []
        batch_size = inputs.size(0)
        max_len = inputs.size(1) - 1  # minus the start of sequence symbol
        decoder_hidden = torch.FloatTensor(self.n_layers, batch_size,
                                           self.hidden_size).uniform_(
                                               -0.1, 0.1).to(self.device)
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        if use_beam_search:  # beam-search decoding
            inputs = inputs[:, 0].unsqueeze(1)
            beam = Beam(k=self.k,
                        decoder_hidden=decoder_hidden,
                        decoder=self,
                        batch_size=batch_size,
                        max_len=max_len,
                        function=function,
                        device=self.device)
            y_hats = beam.search(inputs, encoder_outputs)
        else:
            if use_teacher_forcing:  # with teacher forcing, infer the whole sequence in one pass
                inputs = inputs[:, :-1]  # drop the trailing </s> so inputs align with the targets
                predicted_softmax = self._forward_step(
                    input=inputs,
                    decoder_hidden=decoder_hidden,
                    encoder_outputs=encoder_outputs,
                    function=function)
                for di in range(predicted_softmax.size(1)):
                    step_output = predicted_softmax[:, di, :]
                    decode_results.append(step_output)
            else:
                input = inputs[:, 0].unsqueeze(1)
                for di in range(max_len):
                    predicted_softmax = self._forward_step(
                        input=input,
                        decoder_hidden=decoder_hidden,
                        encoder_outputs=encoder_outputs,
                        function=function)
                    step_output = predicted_softmax.squeeze(1)
                    decode_results.append(step_output)
                    input = decode_results[-1].topk(1)[1]

            logits = torch.stack(decode_results, dim=1).to(self.device)
            y_hats = logits.max(-1)[1]
        return y_hats, logits
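
This variant initializes the decoder hidden state with small uniform noise instead of zeros and reuses it across decoding steps. A minimal sketch of that initialization driving a single step of an nn.GRU, assuming a unidirectional, batch-first GRU decoder (the actual RNN configuration is not shown in the example):

import torch
import torch.nn as nn

# illustrative sizes, not taken from the example above
n_layers, batch_size, hidden_size = 2, 4, 32
gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size,
             num_layers=n_layers, batch_first=True)

# hidden state shape expected by nn.GRU: (num_layers, batch, hidden)
decoder_hidden = torch.FloatTensor(n_layers, batch_size,
                                   hidden_size).uniform_(-0.1, 0.1)

step_input = torch.randn(batch_size, 1, hidden_size)        # one decoding step
output, decoder_hidden = gru(step_input, decoder_hidden)    # hidden carried to the next step
print(output.shape, decoder_hidden.shape)                    # (4, 1, 32) and (2, 4, 32)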
Example #3
    def forward(self,
                inputs=None,
                listener_hidden=None,
                listener_outputs=None,
                function=F.log_softmax,
                teacher_forcing_ratio=0.99):
        y_hats, logit = None, None
        decode_results = []
        # Infer batch size and maximum decoding length from the inputs
        batch_size = inputs.size(0)
        max_length = inputs.size(1) - 1  # minus the start of sequence symbol
        # Initialize the speller hidden state with uniform noise : (layers, batch, hidden)
        speller_hidden = torch.FloatTensor(self.layer_size, batch_size,
                                           self.hidden_size).uniform_(
                                               -1.0, 1.0)  #.cuda()
        # Decide whether to use teacher forcing for this batch
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        if self.use_beam_search:  # beam-search decoding
            speller_input = inputs[:, 0].unsqueeze(1)
            beam = Beam(k=self.k,
                        speller_hidden=speller_hidden,
                        decoder=self,
                        batch_size=batch_size,
                        max_len=max_length,
                        decode_func=function)
            y_hats = beam.search(speller_input, listener_outputs)
        else:
            # Manual unrolling is used to support random teacher forcing.
            # If teacher_forcing_ratio were fixed at 0 or 1 instead of a probability, the unrolling could be done in-graph.
            if use_teacher_forcing:  # with teacher forcing, infer the whole sequence in one pass
                speller_input = inputs[:, :-1]  # drop the trailing </s>
                predicted_softmax = self._forward_step(speller_input,
                                                       speller_hidden,
                                                       listener_outputs,
                                                       function=function)
                """Extract Output by Step"""
                for di in range(predicted_softmax.size(1)):
                    step_output = predicted_softmax[:, di, :]
                    decode_results.append(step_output)
            else:
                speller_input = inputs[:, 0].unsqueeze(1)
                for di in range(max_length):
                    predicted_softmax = self._forward_step(speller_input,
                                                           speller_hidden,
                                                           listener_outputs,
                                                           function=function)
                    # step_output shape: (batch_size, num_classes)
                    step_output = predicted_softmax.squeeze(1)
                    decode_results.append(step_output)
                    speller_input = decode_results[-1].topk(1)[1]

            logit = torch.stack(decode_results, dim=1).to(self.device)
            y_hats = logit.max(-1)[1]
        print("Speller y_hats ====================")
        print(y_hats)

        # Note: operator precedence makes this return (y_hats, logit) when training
        # and (y_hats, y_hats) otherwise; the conditional binds only to the second element.
        return y_hats, logit if self.training else y_hats
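
Because every step output passes through F.log_softmax, the stacked logit tensor holds per-token log-probabilities, so training code would typically pair it with nn.NLLLoss. A minimal sketch under that assumption, with random stand-ins for the speller outputs and targets (shapes are illustrative only):

import torch
import torch.nn as nn

# illustrative shapes; random stand-ins for the speller outputs and targets
batch_size, max_length, num_classes = 4, 12, 40
logit = torch.randn(batch_size, max_length, num_classes).log_softmax(dim=-1)
targets = torch.randint(0, num_classes, (batch_size, max_length))

criterion = nn.NLLLoss()                                     # expects log-probabilities
loss = criterion(logit.reshape(-1, num_classes), targets.reshape(-1))

y_hats = logit.max(-1)[1]                                    # greedy predictions, (batch, max_length)
accuracy = (y_hats == targets).float().mean()
print(loss.item(), accuracy.item())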