Example #1
    def forward(self, padded_input, input_lengths, return_attns=False):
        """
        Args:
            padded_input: N x T x D
            input_lengths: N

        Returns:
            enc_output: N x T x H
        """
        enc_slf_attn_list = []

        # Prepare masks
        non_pad_mask = get_non_pad_mask(padded_input,
                                        input_lengths=input_lengths)
        length = padded_input.size(1)
        slf_attn_mask = get_attn_pad_mask(padded_input, input_lengths, length)

        # Forward
        enc_output = self.dropout(
            self.layer_norm_in(self.linear_in(padded_input)) +
            self.positional_encoding(padded_input))
        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, non_pad_mask,
                                                 slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        if return_attns:
            return enc_output, enc_slf_attn_list

        return (enc_output, )
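This encoder relies on mask helpers that are not shown. Below is a minimal sketch of what they typically compute in length-based (speech-style) encoders like this one; it is an assumption about their behavior, not the repository's actual implementation.

import torch

def get_non_pad_mask(padded_input, input_lengths):
    """N x T x 1 float mask: 1 for real frames, 0 for padded frames."""
    N, T = padded_input.size(0), padded_input.size(1)
    mask = torch.ones(N, T, device=padded_input.device)
    for i, length in enumerate(input_lengths):
        mask[i, int(length):] = 0
    return mask.unsqueeze(-1)

def get_attn_pad_mask(padded_input, input_lengths, expand_length):
    """N x expand_length x T bool mask: True at padded key positions."""
    non_pad_mask = get_non_pad_mask(padded_input, input_lengths)  # N x T x 1
    pad_mask = non_pad_mask.squeeze(-1).eq(0)                     # N x T, True where padded
    return pad_mask.unsqueeze(1).expand(-1, expand_length, -1)    # N x expand_length x T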
Example #2
    def forward(self, src_seq, return_attns=False):
        '''
        Args:
            src_seq: N x T padded source token indices.
        The position indices used for the positional embedding are derived
        from src_seq below.
        '''
        enc_slf_attn_list = []
        src_pos = self.get_positional_input(src_seq).to(src_seq.device)
        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)  # self-attention: keys and queries are the same sequence
        non_pad_mask = get_non_pad_mask(src_seq)

        # -- Forward: sum of the token embedding and the positional embedding
        enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output,
                                                 non_pad_mask=non_pad_mask,
                                                 slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        if return_attns:
            return enc_output, enc_slf_attn_list  # attention maps can be returned for inspection/visualization
        return enc_output,
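Since this encoder works on token id sequences, its masks are usually derived from a PAD index rather than from explicit lengths. A minimal sketch under that assumption follows; the PAD value and the free-function form of get_positional_input are placeholders, not the example's actual code.

import torch

PAD = 0  # assumed padding index

def get_non_pad_mask(seq):
    """N x T x 1 float mask: 1 for real tokens, 0 for PAD."""
    return seq.ne(PAD).float().unsqueeze(-1)

def get_attn_key_pad_mask(seq_k, seq_q):
    """N x len_q x len_k bool mask: True where the key position is PAD."""
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(PAD)                            # N x len_k
    return padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # N x len_q x len_k

def get_positional_input(seq):
    """N x T position indices: 1..T for real tokens, 0 at PAD positions."""
    N, T = seq.size()
    pos = torch.arange(1, T + 1, device=seq.device).unsqueeze(0).expand(N, T)
    return pos * seq.ne(PAD).long()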
Example #3
    def forward(self, inputs, pos, src_inputs, enc_outputs, return_att=False):
        non_pad_mask = utils.get_non_pad_mask(inputs)

        self_attn_mask_subseq = utils.get_subsequent_mask(inputs)
        self_attn_mask_keypad = utils.get_att_key_pad_mask(seq_k=inputs, seq_q=inputs)
        self_attn_mask = (self_attn_mask_keypad + self_attn_mask_subseq).gt(0)

        dec_enc_attn_mask = utils.get_att_key_pad_mask(seq_k=src_inputs, seq_q=inputs)

        self_att_list = []
        enc_att_list = []
        output = self.embed(inputs) + self.position_enc(pos)
        for layer in self.layers:
            output, self_att, enc_att = layer(output,
                                              enc_outputs,
                                              non_pad_mask=non_pad_mask,
                                              input_att_mask=self_attn_mask,
                                              enc_att_mask=dec_enc_attn_mask,
                                             )            
            if return_att:
                self_att_list.append(self_att)
                enc_att_list.append(enc_att)

        output = self.proj(output)
        if return_att:
            return output, self_att_list, enc_att_list
        return output
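Being a decoder, this example additionally needs a causal mask so that position i cannot attend to later positions. A minimal sketch of get_subsequent_mask under that assumption is shown below; in the example it is union-combined with the key-padding mask via (keypad + subseq).gt(0), which with bool masks can equivalently be written keypad | subseq.

import torch

def get_subsequent_mask(seq):
    """N x T x T bool mask: True for positions in the future of the query position."""
    N, T = seq.size()
    future = torch.triu(torch.ones(T, T, device=seq.device), diagonal=1).bool()
    return future.unsqueeze(0).expand(N, -1, -1)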
Example #4
	def forward_greedy(self, z, num_steps, temperature, x=None):
		predictions = []
		batch_size = z.size(0)
		next_input = z.new_zeros(size=(batch_size,num_steps),dtype=torch.long,requires_grad=False)
		next_input[:,:] = self.PAD_TOKEN
		next_input[:,0] = self.SOS_TOKEN # <sos> token
		z = self.activation(self.z2h(z)).view(batch_size,1,-1).repeat(1,num_steps,1)
		for i in range(num_steps):
			input = next_input
			step_input = self.embedding(input)
			step_input = self.pos_embedding(step_input)
			step_input = torch.cat([step_input,z],dim=2) # step_input is of size [batch,seq_len,step_input_size]
			step_input = self.activation(self.s2h(step_input))
			non_pad_mask = get_non_pad_mask(input,self.PAD_TOKEN)
			slf_attn_mask_subseq = get_subsequent_mask(input)
			slf_attn_mask_keypad = get_attn_key_pad_mask(input,self.PAD_TOKEN)
			attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
			out = self.transformer(step_input,non_pad_mask=non_pad_mask,attn_mask=attn_mask)
			out = out[:,i,:]
			out = self.activation(out)
			out = self.h2o(out)
			out = self.last_activation(out,temperature)
			if x is not None: # teacher forcing
				previous_output = x[:,i]
			else: # use prediction as input
				previous_output = torch.argmax(out,dim=-1)
				previous_output = previous_output.detach()
			next_input = torch.cat([input[:,:i+1],previous_output.view(-1,1),input[:,i+2:]],dim=1).detach()
			predictions.append(out)
		output = torch.stack(predictions).transpose(1,0)
		return output
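forward_greedy runs the transformer once per time step, keeps only the output at position i, and optionally teacher-forces the next input from x. The last_activation(out, temperature) call is not shown; a common choice, assumed here rather than taken from the repository, is a temperature-scaled softmax:

import torch.nn.functional as F

def softmax_with_temperature(logits, temperature):
    """Temperature-scaled softmax: temperature < 1 sharpens the distribution, > 1 flattens it."""
    return F.softmax(logits / temperature, dim=-1)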
Example #5
    def forward(self,
                padded_input,
                encoder_padded_outputs,
                encoder_input_lengths,
                return_attns=False):
        """
        Args:
            padded_input: N x To
            encoder_padded_outputs: N x Ti x H

        Returns:
        """
        dec_slf_attn_list, dec_enc_attn_list = [], []

        # Get Decoder Input and Output
        ys_in_pad, ys_out_pad = self.preprocess(padded_input)

        # Prepare masks
        non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)

        slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
                                                     seq_q=ys_in_pad,
                                                     pad_idx=self.eos_id)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        output_length = ys_in_pad.size(1)
        dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                              encoder_input_lengths,
                                              output_length)

        # Forward
        dec_output = self.dropout(
            self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
            self.positional_encoding(ys_in_pad))

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output,
                encoder_padded_outputs,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]
                dec_enc_attn_list += [dec_enc_attn]

        # before softmax
        seq_logit = self.tgt_word_prj(dec_output)

        # Return
        pred, gold = seq_logit, ys_out_pad

        if return_attns:
            return pred, gold, dec_slf_attn_list, dec_enc_attn_list
        return pred, gold
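Because pred is returned before the softmax, it would typically feed directly into a cross-entropy loss with the id used to pad gold ignored. A minimal sketch under that assumption; the padding id is whatever preprocess uses and is treated as a parameter here.

import torch.nn.functional as F

def cal_loss(pred, gold, pad_idx):
    """Cross-entropy over unnormalized logits, skipping padded target positions."""
    pred = pred.reshape(-1, pred.size(-1))  # (N*To) x V
    gold = gold.reshape(-1)                 # (N*To)
    return F.cross_entropy(pred, gold, ignore_index=pad_idx)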
Example #6
    def forward(self, inputs, pos, return_att=False):
        self_att_mask = utils.get_att_key_pad_mask(seq_k=inputs, seq_q=inputs)
        non_pad_mask = utils.get_non_pad_mask(inputs)

        att_weights_list = []
        output = self.embed(inputs)
        output = output + self.position_enc(pos)
        for enc_layer in self.layers:
            output, att_weights = enc_layer(output, non_pad_mask=non_pad_mask, mh_att_mask=self_att_mask)
            if return_att:
                att_weights_list.append(att_weights)

        if return_att:
            return output, att_weights_list
        return output
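Several of these examples call self.position_enc(pos). A common implementation, assumed here and not taken from this repository, is a frozen embedding over a sinusoidal table:

import numpy as np
import torch
import torch.nn as nn

def sinusoid_position_table(n_position, d_model, padding_idx=0):
    """Standard sinusoidal position-encoding table with the PAD row zeroed out."""
    position = np.arange(n_position, dtype=np.float64)[:, None]       # n_position x 1
    div = np.power(10000.0, 2 * (np.arange(d_model) // 2) / d_model)  # d_model
    table = position / div                                            # n_position x d_model
    table[:, 0::2] = np.sin(table[:, 0::2])  # even dimensions
    table[:, 1::2] = np.cos(table[:, 1::2])  # odd dimensions
    table[padding_idx] = 0.0
    return torch.FloatTensor(table)

# position_enc would then be, for example:
# self.position_enc = nn.Embedding.from_pretrained(
#     sinusoid_position_table(max_len + 1, d_model), freeze=True)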
Example #7
	def forward_beam(self, z, num_steps, temperature):
		predictions = []
		batch_size = z.size(0)
		next_input = z.new_zeros(size=(batch_size*self.beam_width,num_steps),dtype=torch.long,requires_grad=False)
		next_input[:,:] = self.PAD_TOKEN
		next_input[:,0] = self.SOS_TOKEN # <sos> token
		z = self.activation(self.z2h(z)).view(batch_size,1,-1).repeat(self.beam_width,num_steps,1)
		previous_output = z.new_zeros(size=(batch_size*self.beam_width,),dtype=torch.long)
		previous_output[:] = self.SOS_TOKEN # <sos> token
		# a table for storing the scores
		scores = z.new_zeros(size=(batch_size*self.beam_width,self.output_size))
		# offsets for indexing into the flattened batch*beam dimension, e.g. with batch_size 2 and beam_width 3 this is [0,0,0,3,3,3]; used later for indexing
		beam_displacement = torch.arange(start=0,end=batch_size*self.beam_width,step=self.beam_width,dtype=torch.long,device=z.device).view(-1,1).repeat(1,self.beam_width).view(-1)
		for i in range(num_steps):
			input = next_input.detach()
			step_input = self.embedding(input)
			step_input = self.pos_embedding(step_input)
			step_input = torch.cat([step_input,z],dim=2) # step_input is of size [batch,seq_len,step_input_size]
			step_input = self.activation(self.s2h(step_input))
			non_pad_mask = get_non_pad_mask(input,self.PAD_TOKEN)
			slf_attn_mask_subseq = get_subsequent_mask(input)
			slf_attn_mask_keypad = get_attn_key_pad_mask(input,self.PAD_TOKEN)
			attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
			out = self.transformer(step_input,non_pad_mask=non_pad_mask,attn_mask=attn_mask)
			out = out[:,i,:]
			out = self.activation(out)
			out = self.h2o(out)
			out = self.last_activation(out,temperature)
			# compute new scores
			next_scores = scores + torch.log(out+1e-8)
			# select top-k scores where k is the beam width
			score,outputs = next_scores.view(batch_size,-1).topk(self.beam_width,dim=-1)
			# flatten the output
			outputs = outputs.view(-1)
			# recover the token index within the vocabulary by taking the flat index modulo the vocab size
			indices = torch.fmod(outputs, self.output_size)
			# recover the source beam via floor division; add the beam displacement to get the absolute row index
			beam_indices = torch.div(outputs, self.output_size, rounding_mode='floor') + beam_displacement
			# check if some elements/words are repeated
			res = torch.eq(previous_output,indices).nonzero()
			# while some elements/words are repeated, mask them out and re-select
			retries = 0
			while res.shape[0] > 0:
				mask = torch.ones(size=(batch_size*self.beam_width,self.output_size),requires_grad=False,device=z.device)
				# set the mask to be zero when an option is non selectable
				mask[beam_indices[res],indices[res]] = 0
				# apply the mask
				out = out * mask
				# set the score for the repeated elements to be low
				next_scores = scores + torch.log(out+1e-8)
				# select top-k scores where k is the beam width
				score,outputs = next_scores.view(batch_size,-1).topk(self.beam_width,dim=-1)
				# flatten the output
				outputs = outputs.view(-1)
				# recover the token index within the vocabulary by taking the flat index modulo the vocab size
				indices = torch.fmod(outputs, self.output_size)
				# recover the source beam via floor division; add the beam displacement to get the absolute row index
				beam_indices = torch.div(outputs, self.output_size, rounding_mode='floor') + beam_displacement
				# check if some elements/words are repeated
				res = torch.eq(previous_output,indices).nonzero()
				if retries > 10:
					break
				retries += 1
			# copy the score for each selected candidate
			scores = score.view(-1,1).repeat(1,self.output_size)
			# renormalize the output
			out = out/out.sum(-1).view(-1,1).repeat(1,self.output_size)
			# append the prediction to output
			predictions.append(out[beam_indices,:])
			# detach the output such that we don't backpropagate through timesteps
			previous_output = indices.detach()
			next_input = torch.cat([input[:,:i+1],previous_output.view(-1,1),input[:,i+2:]],dim=1)
		output = torch.stack(predictions).transpose(1,0)
		# initialize an output_mask such that we can filter out sentences
		output_mask = torch.zeros_like(output)
		# set the selected sentences output_mask to 1
		output_mask[scores[:,0].view(batch_size,-1).argmax(dim=-1) + beam_displacement.view(batch_size,-1)[:,0]] = 1
		# collect the best prediction for each sample in batch
		output = (output*output_mask).view(batch_size,self.beam_width,num_steps,self.output_size)
		# sum over the beam dimension; the non-selected sentences are all zeros, so this does not change the selected sentences
		output = output.sum(1)
		return output
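The top-k in forward_beam is taken over a view of shape [batch, beam_width * vocab], so each selected flat index encodes both the source beam and the token id; fmod and floor division split them apart. A tiny standalone check of that decomposition (the numbers are made up):

import torch

vocab_size, beam_width = 5, 3
flat = torch.tensor([0, 7, 14])                                  # flat indices into a beam*vocab axis
token_ids = torch.fmod(flat, vocab_size)                         # tensor([0, 2, 4])
beam_ids = torch.div(flat, vocab_size, rounding_mode='floor')    # tensor([0, 1, 2])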