def forward_greedy(self,z,num_steps,temperature,x=None):
		predictions = []
		batch_size = z.size(0)
		next_input = z.new_zeros(size=(batch_size,num_steps),dtype=torch.long,requires_grad=False)
		next_input[:,:] = self.PAD_TOKEN
		next_input[:,0] = self.SOS_TOKEN # <sos> token
		z = self.activation(self.z2h(z)).view(batch_size,1,-1).repeat(1,num_steps,1)
		for i in range(num_steps):
			input = next_input
			step_input = self.embedding(input)
			step_input = self.pos_embedding(step_input)
			step_input = torch.cat([step_input,z],dim=2) # step_input is of size [batch,seq_len,step_input_size]
			step_input = self.activation(self.s2h(step_input))
			non_pad_mask = get_non_pad_mask(input,self.PAD_TOKEN)
			slf_attn_mask_subseq = get_subsequent_mask(input)
			slf_attn_mask_keypad = get_attn_key_pad_mask(input,self.PAD_TOKEN)
			attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
			out = self.transformer(step_input,non_pad_mask=non_pad_mask,attn_mask=attn_mask)
			out = out[:,i,:]
			out = self.activation(out)
			out = self.h2o(out)
			out = self.last_activation(out,temperature)
			if x is not None: # teacher forcing
				previous_output = x[:,i]
			else: # use prediction as input
				previous_output = torch.argmax(out,dim=-1)
				previous_output = previous_output.detach()
			next_input = torch.cat([input[:,:i+1],previous_output.view(-1,1),input[:,i+2:]],dim=1).detach()
			predictions.append(out)
		output = torch.stack(predictions).transpose(1,0)
		return output
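Every example in this collection calls mask helpers such as get_non_pad_mask, get_subsequent_mask and get_attn_key_pad_mask without defining them, and the exact signatures vary between projects (compare the two-argument calls above with the keyword calls in Example 3). As a point of reference, here is a minimal sketch of the usual implementations, written to be consistent with combining the byte masks via + and .gt(0) as above; the helpers in each project may differ in detail.

import torch

def get_non_pad_mask(seq, pad_idx):
    # [batch, len, 1] float mask: 1.0 for real tokens, 0.0 for padding
    return seq.ne(pad_idx).type(torch.float).unsqueeze(-1)

def get_subsequent_mask(seq):
    # [batch, len, len] byte mask: 1 where a query position would attend to a future key
    batch_size, seq_len = seq.size()
    mask = torch.triu(torch.ones((seq_len, seq_len), device=seq.device, dtype=torch.uint8), diagonal=1)
    return mask.unsqueeze(0).expand(batch_size, -1, -1)

def get_attn_key_pad_mask(seq_k, seq_q, pad_idx):
    # [batch, len_q, len_k] byte mask: 1 where the key position is padding
    len_q = seq_q.size(1)
    return seq_k.eq(pad_idx).unsqueeze(1).expand(-1, len_q, -1).to(torch.uint8)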
Example 2
    def forward(self, inputs, pos, src_inputs, enc_outputs, return_att=False):
        non_pad_mask = utils.get_non_pad_mask(inputs)

        self_attn_mask_subseq = utils.get_subsequent_mask(inputs)
        self_attn_mask_keypad = utils.get_att_key_pad_mask(seq_k=inputs, seq_q=inputs)
        self_attn_mask = (self_attn_mask_keypad + self_attn_mask_subseq).gt(0)

        dec_enc_attn_mask = utils.get_att_key_pad_mask(seq_k=src_inputs, seq_q=inputs)

        self_att_list = []
        enc_att_list = []
        output = self.embed(inputs) + self.position_enc(pos)
        for layer in self.layers:
            output, self_att, enc_att = layer(output,
                                              enc_outputs,
                                              non_pad_mask=non_pad_mask,
                                              input_att_mask=self_attn_mask,
                                              enc_att_mask=dec_enc_attn_mask,
                                             )            
            if return_att:
                self_att_list.append(self_att)
                enc_att_list.append(enc_att)

        output = self.proj(output)
        return output
Example 3
    def forward(self,
                padded_input,
                encoder_padded_outputs,
                encoder_input_lengths,
                return_attns=False):
        """
        Args:
            padded_input: N x To
            encoder_padded_outputs: N x Ti x H

        Returns:
        """
        dec_slf_attn_list, dec_enc_attn_list = [], []

        # Get Decoder Input and Output
        ys_in_pad, ys_out_pad = self.preprocess(padded_input)

        # Prepare masks
        non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)

        slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
                                                     seq_q=ys_in_pad,
                                                     pad_idx=self.eos_id)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        output_length = ys_in_pad.size(1)
        dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                              encoder_input_lengths,
                                              output_length)

        # Forward
        dec_output = self.dropout(
            self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
            self.positional_encoding(ys_in_pad))

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output,
                encoder_padded_outputs,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]
                dec_enc_attn_list += [dec_enc_attn]

        # before softmax
        seq_logit = self.tgt_word_prj(dec_output)

        # Return
        pred, gold = seq_logit, ys_out_pad

        if return_attns:
            return pred, gold, dec_slf_attn_list, dec_enc_attn_list
        return pred, gold
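Example 3 additionally relies on get_attn_pad_mask, which builds the decoder-to-encoder attention mask from the true encoder frame lengths rather than from a pad token. A plausible minimal implementation, assuming the N x To x Ti shape implied by the call above (the project's actual helper may differ):

import torch

def get_attn_pad_mask(padded_input, input_lengths, expand_length):
    # padded_input: N x Ti x H encoder outputs; input_lengths: true length per utterance
    # returns an N x expand_length x Ti mask that is True at padded encoder positions
    batch_size, max_time, _ = padded_input.size()
    non_pad = padded_input.new_ones((batch_size, max_time))
    for i, length in enumerate(input_lengths):
        non_pad[i, length:] = 0
    pad_mask = non_pad.lt(1)  # True where the frame is padding
    return pad_mask.unsqueeze(1).expand(-1, expand_length, -1)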
Example 4
            def predict_word(dec_seq, src_seq, enc_output, n_active_inst, n_bm):
                src_mask = get_pad_mask(src_seq, PAD)
                dec_mask = get_pad_mask(dec_seq, PAD) & get_subsequent_mask(dec_seq)
                dec_output, *_ = self.model.decoder(dec_seq, dec_mask, enc_output, src_mask)
                dec_output = dec_output[:, -1, :]  # Pick the last step: (bh * bm) * d_h
                word_prob = F.log_softmax(self.model.trg_word_prj(dec_output), dim=1)
                word_prob = word_prob.view(n_active_inst, n_bm, -1)

                return word_prob
Example 5
    def forward(self, src_seq, trg_seq):

        src_mask = get_pad_mask(src_seq, self.src_pad_idx)
        trg_mask = get_pad_mask(
            trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq)

        enc_output, *_ = self.encoder(src_seq, src_mask)
        dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask)
        seq_logit = self.trg_word_prj(dec_output) * self.x_logit_scale

        # return seq_logit.view(-1, seq_logit.size(2))
        return seq_logit
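Note that Examples 4-6 use the opposite mask polarity from Examples 1-3: get_pad_mask and get_subsequent_mask here return boolean keep-masks (True = may attend) that are combined with &, instead of byte masks of forbidden positions combined with + and .gt(0). A minimal sketch consistent with that usage:

import torch

def get_pad_mask(seq, pad_idx):
    # [batch, 1, len] boolean mask, True where the token is NOT padding
    return (seq != pad_idx).unsqueeze(-2)

def get_subsequent_mask(seq):
    # [1, len, len] boolean mask, True at the current and earlier positions
    seq_len = seq.size(1)
    ones = torch.ones((1, seq_len, seq_len), device=seq.device)
    return (1 - torch.triu(ones, diagonal=1)).bool()

With these shapes, get_pad_mask(trg_seq, pad_idx) & get_subsequent_mask(trg_seq) broadcasts to a [batch, len, len] mask, which is what the decoder call in Example 5 expects.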
Example 6
def decoder(decoder_model, model_cls, decode_input_ids, encode_outputs,
            encode_attention_mask):
    decode_attention_mask = get_subsequent_mask(decode_input_ids)
    decode_outputs, *_ = decoder_model(
        decode_input_ids,
        decode_attention_mask,
        encoder_output=encode_outputs,
        encoder_attention_mask=encode_attention_mask)

    logits = model_cls(decode_outputs)

    return logits
Example 7
    def recognize_beam(self, encoder_outputs, char_list, args):
        """Beam search, decode one utterence now.
        Args:
            encoder_outputs: T x H
            char_list: list of character
            args: args.beam

        Returns:
            nbest_hyps:
        """
        # search params
        beam = args.beam_size
        nbest = args.nbest
        if args.decode_max_len == 0:
            maxlen = encoder_outputs.size(0)
        else:
            maxlen = args.decode_max_len

        encoder_outputs = encoder_outputs.unsqueeze(0)

        # prepare sos
        ys = torch.ones(1,
                        1).fill_(self.sos_id).type_as(encoder_outputs).long()

        # yseq: 1xT
        hyp = {'score': 0.0, 'yseq': ys}
        hyps = [hyp]
        ended_hyps = []

        for i in range(maxlen):
            hyps_best_kept = []
            for hyp in hyps:
                ys = hyp['yseq']  # 1 x i

                # -- Prepare masks
                non_pad_mask = torch.ones_like(ys).float().unsqueeze(
                    -1)  # 1xix1
                slf_attn_mask = get_subsequent_mask(ys)

                # -- Forward
                dec_output = self.dropout(
                    self.tgt_word_emb(ys) * self.x_logit_scale +
                    self.positional_encoding(ys))

                for dec_layer in self.layer_stack:
                    dec_output, _, _ = dec_layer(dec_output,
                                                 encoder_outputs,
                                                 non_pad_mask=non_pad_mask,
                                                 slf_attn_mask=slf_attn_mask,
                                                 dec_enc_attn_mask=None)

                seq_logit = self.tgt_word_prj(dec_output[:, -1])

                local_scores = F.log_softmax(seq_logit, dim=1)
                # topk scores
                local_best_scores, local_best_ids = torch.topk(local_scores,
                                                               beam,
                                                               dim=1)

                for j in range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                    new_hyp['yseq'] = torch.ones(
                        1, (1 + ys.size(1))).type_as(encoder_outputs).long()
                    new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq']
                    new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j])
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]
            # end for hyp in hyps
            hyps = hyps_best_kept

            # append <eos> at the final step so that at least one hypothesis ends
            if i == maxlen - 1:
                for hyp in hyps:
                    hyp['yseq'] = torch.cat([
                        hyp['yseq'],
                        torch.ones(1, 1).fill_(
                            self.eos_id).type_as(encoder_outputs).long()
                    ],
                                            dim=1)

            # move ended hypotheses to the final list and remove them from the current ones
            # (this can be a problem: the number of hyps may fall below beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][0, -1] == self.eos_id:
                    ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            hyps = remained_hyps
            if len(hyps) > 0:
                print('remaining hypotheses: ' + str(len(hyps)))
            else:
                print('no hypothesis. Finish decoding.')
                break

            for hyp in hyps:
                # print('fch: ')
                # print(str(x.encode('utf-8')) for x in char_list)
                # print(([(char_list[int(x)]).encode('utf-8') for x in hyp['yseq'][0, 1:]]))
                # print('hypo: ' + ''.join([str(char_list[int(x)].encode('utf-8'))
                # for x in hyp['yseq'][0, 1:]]))
                print('hypo: ' +
                      ''.join([char_list[int(x)] for x in hyp['yseq'][0, 1:]]))
        # end for i in range(maxlen)
        nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'],
                            reverse=True)[:min(len(ended_hyps), nbest)]
        # compatible with LAS implementation
        for hyp in nbest_hyps:
            hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()
        return nbest_hyps
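recognize_beam only reads beam_size, nbest and decode_max_len from args, so it can be exercised with a simple namespace. A hedged usage sketch, in which decoder, encoder_outputs (T x H) and char_list are placeholders for objects created elsewhere:

from types import SimpleNamespace

# decoder, encoder_outputs and char_list are assumed to exist already
args = SimpleNamespace(beam_size=5, nbest=1, decode_max_len=0)  # 0 -> use the encoder length
nbest_hyps = decoder.recognize_beam(encoder_outputs, char_list, args)
best = nbest_hyps[0]
print(''.join(char_list[idx] for idx in best['yseq'][1:-1]))  # drop <sos> and <eos>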
Example 8
	def forward_beam(self,z,num_steps,temperature):
		predictions = []
		batch_size = z.size(0)
		next_input = z.new_zeros(size=(batch_size*self.beam_width,num_steps),dtype=torch.long,requires_grad=False)
		next_input[:,:] = self.PAD_TOKEN
		next_input[:,0] = self.SOS_TOKEN # <sos> token
		z = self.activation(self.z2h(z)).view(batch_size,1,-1).repeat(1,num_steps,1)
		# repeat each sample's latent beam_width times so that a sample's beams occupy contiguous rows, matching beam_displacement below
		z = z.repeat_interleave(self.beam_width,dim=0)
		previous_output = z.new_zeros(size=(batch_size*self.beam_width,),dtype=torch.long)
		previous_output[:] = self.SOS_TOKEN # <sos> token
		# a table for storing the scores
		scores = z.new_zeros(size=(batch_size*self.beam_width,self.output_size))
		# row offsets into the flattened beam dimension, e.g. with batch_size 2 and beam_width 3 this is [0,0,0,3,3,3]; used later for indexing
		beam_displacement = torch.arange(start=0,end=batch_size*self.beam_width,step=self.beam_width,dtype=torch.long,device=z.device).view(-1,1).repeat(1,self.beam_width).view(-1)
		for i in range(num_steps):
			input = next_input.detach()
			step_input = self.embedding(input)
			step_input = self.pos_embedding(step_input)
			step_input = torch.cat([step_input,z],dim=2) # step_input is of size [batch,seq_len,step_input_size]
			step_input = self.activation(self.s2h(step_input))
			non_pad_mask = get_non_pad_mask(input,self.PAD_TOKEN)
			slf_attn_mask_subseq = get_subsequent_mask(input)
			slf_attn_mask_keypad = get_attn_key_pad_mask(input,self.PAD_TOKEN)
			attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
			out = self.transformer(step_input,non_pad_mask=non_pad_mask,attn_mask=attn_mask)
			out = out[:,i,:]
			out = self.activation(out)
			out = self.h2o(out)
			out = self.last_activation(out,temperature)
			# compute new scores
			next_scores = scores + torch.log(out+1e-8)
			# select top-k scores where k is the beam width
			score,outputs = next_scores.view(batch_size,-1).topk(self.beam_width,dim=-1)
			# flatten the output
			outputs = outputs.view(-1)
			# get the token id within the vocabulary by taking the candidate index modulo the vocab size
			indices = torch.fmod(outputs,self.output_size)
			# find the beam/batch row the candidate came from. Add the beam displacement to get the global row index
			beam_indices = outputs // self.output_size + beam_displacement
			# check if some elements/words are repeated
			res = torch.eq(previous_output,indices).nonzero()
			# some elements/words are repeated
			retries = 0
			while res.shape[0] > 0:
				mask = torch.ones(size=(batch_size*self.beam_width,self.output_size),requires_grad=False,device=z.device)
				# set the mask to zero where an option is not selectable
				mask[beam_indices[res],indices[res]] = 0
				# apply the mask
				out = out * mask
				# set the score for the repeated elements to be low
				next_scores = scores + torch.log(out+1e-8)
				# select top-k scores where k is the beam width
				score,outputs = next_scores.view(batch_size,-1).topk(self.beam_width,dim=-1)
				# flatten the output
				outputs = outputs.view(-1)
				# get the token id within the vocabulary by taking the candidate index modulo the vocab size
				indices = torch.fmod(outputs,self.output_size)
				# find the beam/batch row the candidate came from. Add the beam displacement to get the global row index
				beam_indices = outputs // self.output_size + beam_displacement
				# check if some elements/words are repeated
				res = torch.eq(previous_output,indices).nonzero()
				if retries > 10:
					break
				retries += 1
			# copy the score for each selected candidate
			scores = score.view(-1,1).repeat(1,self.output_size)
			# renormalize the output
			out = out/out.sum(-1).view(-1,1).repeat(1,self.output_size)
			# append the prediction to output
			predictions.append(out[beam_indices,:])
			# detach the output such that we don't backpropagate through timesteps
			previous_output = indices.detach()
			next_input = torch.cat([input[:,:i+1],previous_output.view(-1,1),input[:,i+2:]],dim=1)
		output = torch.stack(predictions).transpose(1,0)
		# initialize an output_mask such that we can filter out sentences
		output_mask = torch.zeros_like(output)
		# set the selected sentences output_mask to 1
		output_mask[scores[:,0].view(batch_size,-1).argmax(dim=-1) + beam_displacement.view(batch_size,-1)[:,0]] = 1
		# collect the best prediction for each sample in batch
		output = (output*output_mask).view(batch_size,self.beam_width,num_steps,self.output_size)
		# sum over the beam dimension; the sentences that were not selected are zero, so this does not change the selected ones
		output = output.sum(1)
		return output
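The core of forward_beam is the usual flattened top-k trick: the scores of all beams of one sample are viewed as a single row, and each chosen candidate index is decomposed back into a token id and the row of the beam it came from. A small self-contained illustration with made-up numbers (batch_size 2, beam_width 3, vocabulary of 10):

import torch

batch_size, beam_width, vocab_size = 2, 3, 10
# pretend per-beam cumulative log-probabilities, shape [batch*beam, vocab]
next_scores = torch.randn(batch_size * beam_width, vocab_size)

# [0,0,0,3,3,3]: row offset of each sample's first beam
beam_displacement = torch.arange(0, batch_size * beam_width, beam_width).view(-1, 1).repeat(1, beam_width).view(-1)

# each reshaped row holds beam*vocab candidates belonging to one sample
score, outputs = next_scores.view(batch_size, -1).topk(beam_width, dim=-1)
outputs = outputs.view(-1)

indices = torch.fmod(outputs, vocab_size)                 # token id within the vocabulary
beam_indices = outputs // vocab_size + beam_displacement  # global row of the source beam
print(indices.tolist(), beam_indices.tolist())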
Example 9
    def recognize_beam(self, encoder_outputs, char_list, args):
        """
        Beam search, decode one utterance at a time.
        Args:
            encoder_outputs: T x H #418 x 512
            char_list: list of character #4233
            args: args.beam #5

        Returns:
            nbest_hyps:
        """
        # search params
        beam = args.beam_size
        nbest = args.nbest
        if args.decode_max_len == 0:
            maxlen = encoder_outputs.size(0)
        else:
            maxlen = args.decode_max_len

        encoder_outputs = encoder_outputs.unsqueeze(0)
        # prepare sos
        ys = flow.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long()
        hyp = {"score": 0.0, "yseq": ys}
        hyps = [hyp]
        ended_hyps = []

        for i in range(maxlen):
            hyps_best_kept = []
            for hyp in hyps:
                ys = hyp["yseq"]
                ys = ys.to(device=encoder_outputs.device)
                # -- Prepare masks
                non_pad_mask = flow.ones_like(ys).to(
                    dtype=flow.float32).unsqueeze(-1)
                slf_attn_mask = get_subsequent_mask(ys)
                # -- Forward
                dec_output = self.dropout(
                    self.tgt_word_emb(ys) * self.x_logit_scale +
                    self.positional_encoding(ys))

                for dec_layer in self.layer_stack:
                    dec_output, _, _ = dec_layer(
                        dec_output,
                        encoder_outputs,
                        non_pad_mask=non_pad_mask,
                        slf_attn_mask=slf_attn_mask,
                        dec_enc_attn_mask=None,
                    )

                seq_logit = self.tgt_word_prj(dec_output[:, -1])
                local_logit = F.softmax(seq_logit, dim=1)
                local_scores = flow.log(local_logit)
                # topk scores
                local_best_scores, local_best_ids = flow.topk(local_scores,
                                                              beam,
                                                              dim=1)

                for j in range(beam):
                    new_hyp = {}
                    new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                    new_hyp["yseq"] = (flow.ones(
                        1, (1 + ys.size(1))).type_as(encoder_outputs).long())
                    new_hyp["yseq"][:, :ys.size(1)] = hyp["yseq"]
                    new_hyp["yseq"][:, ys.size(1)] = int(
                        float(local_best_ids[0, j].numpy()))
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x["score"],
                                        reverse=True)[:beam]
            # end for hyp in hyps
            hyps = hyps_best_kept
            # append <eos> at the final step so that at least one hypothesis ends
            if i == maxlen - 1:
                for hyp in hyps:
                    hyp["yseq"] = flow.cat(
                        [
                            hyp["yseq"],
                            flow.ones(1, 1).fill_(
                                self.eos_id).type_as(encoder_outputs).long(),
                        ],
                        dim=1,
                    )

            # move ended hypotheses to the final list and remove them from the current ones
            # (this can be a problem: the number of hyps may fall below beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][0, -1] == self.eos_id:
                    ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            hyps = remained_hyps
            if len(hyps) > 0:
                print("remeined hypothes: " + str(len(hyps)))
            else:
                print("no hypothesis. Finish decoding.")
                break
            for hyp in hyps:
                print("hypo: " + "".join(
                    [char_list[int(x.numpy())] for x in hyp["yseq"][0, 1:]]))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"],
                            reverse=True)[:min(len(ended_hyps), nbest)]
        for hyp in nbest_hyps:
            hyp["yseq"] = hyp["yseq"][0].cpu().numpy().tolist()
        return nbest_hyps