def initialize_output_tokens(self, encoder_out, prev_output_tokens):
    """Build the initial decoder state from a heuristic length estimate.

    Each target length is guessed as ``1.2 * source_length + 10``.  The
    canvas per sentence is ``unk ... unk eos pad ... pad``: positions
    before the predicted end are ``unk`` (to be refined by later
    iterations), the last predicted position is ``eos``, and the rest is
    ``pad``.  Scores start at zero in the encoder's floating dtype.
    """
    # Unpadded source lengths, shape (bsz, 1).
    src_lengths = (~encoder_out.encoder_padding_mask).sum(1, keepdim=True)
    length_tgt = (src_lengths * 1.2 + 10).long()
    max_len = length_tgt.max()
    bsz = prev_output_tokens.size(0)

    # Start from an all-pad canvas.
    initial_output_tokens = prev_output_tokens.new_full([bsz, max_len], self.pad)

    # Row of position indices broadcast over the batch; used only for
    # comparisons, so an expanded view is sufficient.
    arange_mat = (
        torch.arange(max_len, device=prev_output_tokens.device)
        .unsqueeze(0)
        .expand(bsz, -1)
    )
    initial_output_tokens = initial_output_tokens.masked_fill(
        arange_mat < length_tgt - 1, self.unk
    )
    initial_output_tokens = initial_output_tokens.masked_fill(
        arange_mat.eq(length_tgt - 1), self.eos
    )

    initial_output_scores = initial_output_tokens.new_zeros(
        initial_output_tokens.size()
    ).type_as(encoder_out.encoder_out)

    return DecoderOut(
        output_tokens=initial_output_tokens,
        output_scores=initial_output_scores,
        attn=None,
        step=0,
        max_step=10000,
        history=None,
    )
def initialize_output_tokens(self, encoder_out, src_tokens):
    """Create the initial decoder output using the learned length predictor.

    Every sentence is laid out as ``bos unk ... unk eos`` followed by
    ``pad`` up to the longest predicted length in the batch; scores are
    zero-initialized in the encoder's floating dtype.
    """
    # Predict one target length per sentence.
    length_tgt = self.decoder.forward_length_prediction(
        self.decoder.forward_length(normalize=True, encoder_out=encoder_out),
        encoder_out=encoder_out,
    )
    # Clamp so there is always room for at least bos + eos.
    max_length = length_tgt.clamp_(min=2).max()
    positions = utils.new_arange(src_tokens, max_length)

    # All-pad canvas, then unk up to each predicted length.
    initial_output_tokens = src_tokens.new_full(
        (src_tokens.size(0), max_length), self.pad
    )
    initial_output_tokens.masked_fill_(
        positions[None, :] < length_tgt[:, None], self.unk
    )
    initial_output_tokens[:, 0] = self.bos
    # eos sits at the last predicted position of each row.
    initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

    initial_output_scores = initial_output_tokens.new_zeros(
        initial_output_tokens.size()
    ).type_as(encoder_out.encoder_out)

    return DecoderOut(
        output_tokens=initial_output_tokens,
        output_scores=initial_output_scores,
        attn=None,
        step=0,
        max_step=0,
        history=None,
    )
def initialize_output_tokens(self, encoder_out, src_tokens):
    """Initialize the decoder canvas with target length = 2x source length.

    Source lengths come from the encoder padding mask (or the full time
    dimension when no mask is present).  Each row is then
    ``bos unk ... unk eos pad ... pad`` with zero scores.
    """
    enc_feats = encoder_out.encoder_out  # T x B x C
    src_masks = encoder_out.encoder_padding_mask  # B x T or None

    # Per-sentence source lengths; with no mask every sentence spans T.
    if src_masks is None:
        src_lengs = enc_feats.new_ones(enc_feats.size(1)).fill_(enc_feats.size(0))
    else:
        src_lengs = (~src_masks).transpose(0, 1).type_as(enc_feats).sum(0)
    src_lengs = src_lengs.long()

    # Doubling heuristic; clamp guarantees room for bos + eos.
    length_tgt = src_lengs * 2
    max_length = length_tgt.clamp_(min=2).max()
    positions = utils.new_arange(src_tokens, max_length)

    initial_output_tokens = src_tokens.new_full(
        (src_tokens.size(0), max_length), self.pad
    )
    initial_output_tokens.masked_fill_(
        positions[None, :] < length_tgt[:, None], self.unk
    )
    initial_output_tokens[:, 0] = self.bos
    initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

    initial_output_scores = initial_output_tokens.new_zeros(
        initial_output_tokens.size()
    ).type_as(encoder_out.encoder_out)

    return DecoderOut(
        output_tokens=initial_output_tokens,
        output_scores=initial_output_scores,
        attn=None,
        step=0,
        max_step=0,
        history=None,
    )
def initialize_output_tokens(self, encoder_out, src_tokens):
    """Start decoding from the minimal two-token canvas ``[bos, eos]``.

    Scores are zero-initialized in the encoder output's floating dtype.
    """
    bsz = src_tokens.size(0)
    initial_output_tokens = src_tokens.new_zeros(bsz, 2)
    initial_output_tokens[:, 0] = self.bos
    initial_output_tokens[:, 1] = self.eos

    initial_output_scores = initial_output_tokens.new_zeros(
        initial_output_tokens.size()
    ).type_as(encoder_out.encoder_out)

    return DecoderOut(
        output_tokens=initial_output_tokens,
        output_scores=initial_output_scores,
        attn=None,
        step=0,
        max_step=0,
        history=None,
    )
def initialize_output_tokens(self, encoder_out, src_tokens, tgt_lang):
    """Initialize the decoder canvas for a bilingual (en/ch) NAT model.

    Predicts a target length per sentence, builds a
    ``bos <lang-tag> unk ... unk eos pad ...`` canvas, and returns the
    per-position language-id tensor, 1-based positions (0 beyond each
    predicted length), and the initial :class:`DecoderOut`.

    Args:
        encoder_out: encoder output (used by the length predictor and for
            the score dtype).
        src_tokens: source token tensor, shape (bsz, src_len).
        tgt_lang: target language, either ``"en"`` or ``"ch"``.

    Returns:
        Tuple of ``(langs, positions, DecoderOut)``.

    Raises:
        ValueError: if ``tgt_lang`` is neither ``"en"`` nor ``"ch"``.
        (Previously an unsupported language fell through an assert/pass
        branch and crashed later with ``UnboundLocalError`` on ``langs``.)
    """
    # Learned length prediction; clamp guarantees room for bos + eos.
    length_tgt = self.decoder.forward_length_prediction(
        self.decoder.forward_length(normalize=True, encoder_out=encoder_out),
        tgt_lengths=None,
    )
    # NOTE: removed leftover debug print of the predicted lengths.
    max_length = length_tgt.clamp_(min=2).max()
    idx_length = utils.new_arange(src_tokens, max_length)

    # 1-based positions, zeroed beyond each sentence's predicted length.
    positions = (
        torch.arange(1, max_length + 1)[None, :]
        .repeat(src_tokens.size(0), 1)
        .to(src_tokens.device)
    )
    positions.masked_fill_(idx_length[None, :] + 1 > length_tgt[:, None], 0)

    initial_output_tokens = src_tokens.new_zeros(
        src_tokens.size(0), max_length
    ).long().fill_(self.pad)
    initial_output_tokens.masked_fill_(
        idx_length[None, :] < length_tgt[:, None], self.unk
    )
    initial_output_tokens[:, 0] = self.bos

    # Position 1 carries the language tag; `langs` is 0 for English and
    # 1 for Chinese at every position.
    if tgt_lang == "en":
        initial_output_tokens[:, 1] = self.en_tag
        langs = src_tokens.new_zeros(src_tokens.size(0), max_length).long()
    elif tgt_lang == "ch":
        initial_output_tokens[:, 1] = self.ch_tag
        langs = src_tokens.new_ones(src_tokens.size(0), max_length).long()
    else:
        raise ValueError(
            "unsupported target language: {!r} (expected 'en' or 'ch')".format(
                tgt_lang
            )
        )

    initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

    initial_output_scores = initial_output_tokens.new_zeros(
        initial_output_tokens.size()
    ).type_as(encoder_out.encoder_out)

    return langs, positions, DecoderOut(
        output_tokens=initial_output_tokens,
        output_scores=initial_output_scores,
        attn=None,
        step=0,
        max_step=0,
        history=None,
    )
def initialize_output_tokens(self, encoder_out, src_tokens, initial_tokens, initial_marks):
    """Seed the decoder canvas with user-supplied constraint tokens.

    Without constraints every row is ``[bos, eos]``.  With constraints a
    row becomes ``bos c_1 ... c_k eos pad ... pad`` (width = longest
    constraint + 2), and ``output_marks`` mirrors the per-token marks at
    the same offsets.  Scores start at zero in the encoder's dtype.
    """
    tok_seqs = initial_tokens.tolist() if initial_tokens is not None else None
    mark_seqs = initial_marks.tolist() if initial_marks is not None else None

    # Canvas width: longest constraint plus bos and eos slots.
    longest = max(len(seq) for seq in tok_seqs) if tok_seqs else 0
    width = longest + 2
    bsz = src_tokens.size(0)

    initial_output_marks = src_tokens.new_zeros(bsz, width)
    initial_output_tokens = src_tokens.new_zeros(bsz, width)
    initial_output_tokens[:, 0] = self.bos
    initial_output_tokens[:, 1] = self.eos

    if tok_seqs:
        for row, seq in enumerate(tok_seqs):
            n = len(seq)
            if n:
                # Constraint tokens occupy positions 1..n.
                initial_output_tokens[row, 1:n + 1] = initial_output_tokens.new_tensor(seq)
            initial_output_tokens[row, n + 1] = self.eos
            initial_output_tokens[row, n + 2:] = self.pad

    if mark_seqs:
        for row, seq in enumerate(mark_seqs):
            if seq:
                initial_output_marks[row, 1:len(seq) + 1] = initial_output_marks.new_tensor(seq)

    initial_output_scores = initial_output_tokens.new_zeros(
        initial_output_tokens.size()
    ).type_as(encoder_out.encoder_out)

    return DecoderOut(
        output_tokens=initial_output_tokens,
        output_marks=initial_output_marks,
        output_scores=initial_output_scores,
        attn=None,
        step=0,
        max_step=0,
        num_ops=(0, 0),
        history=None,
    )
def initialize_output_tokens(self, encoder_out, src_tokens):
    """Initialize with ``[bos, eos]`` canvases plus an attention buffer.

    When the last decoder layer exposes attention (``need_attn``,
    defaulting to True), a zero buffer of shape (bsz, 2, src_len) is
    allocated on the token tensor's device/dtype; otherwise an empty
    tensor is used.
    """
    bsz, src_len = src_tokens.size(0), src_tokens.size(1)
    bos_col = src_tokens.new_full((bsz, 1), self.bos)
    eos_col = src_tokens.new_full((bsz, 1), self.eos)
    initial_output_tokens = torch.cat([bos_col, eos_col], 1)

    # Zero scores in the encoder output's dtype/device.
    initial_output_scores = torch.zeros_like(initial_output_tokens).to(
        encoder_out.encoder_out
    )

    initial_attn = torch.empty([0])
    if getattr(self.decoder.layers[-1], "need_attn", True):
        initial_attn = torch.zeros([bsz, 2, src_len]).to(initial_output_tokens)

    return DecoderOut(
        output_tokens=initial_output_tokens,
        output_scores=initial_output_scores,
        attn=initial_attn,
        step=0,
        max_step=0,
        history=None,
    )