def process_embedding(self, input, input_lang=None):
    # if self.switchout == 0:
    #     input_ = input
    # if self.switchout > 0 and self.training:
    #     vocab_size = self.word_lut.weight.size(0)
    #     input_ = switchout(input, vocab_size, self.switchout)
    # else:
    input_ = input
    emb = embedded_dropout(self.word_lut, input_, dropout=self.word_dropout if self.training else 0)

    if self.time == 'positional_encoding':
        emb = emb * math.sqrt(self.model_size)

    """ Adding positional encoding """
    emb = self.time_transformer(emb)

    if self.use_language_embedding:
        lang_emb = self.language_embeddings(input_lang)  # B x H or 1 x H
        if self.language_embedding_type == 'sum':
            emb = emb + lang_emb
        elif self.language_embedding_type == 'concat':
            # replace the bos embedding with the language embedding
            bos_emb = lang_emb.expand_as(emb[:, 0, :])
            emb[:, 0, :] = bos_emb

            lang_emb = lang_emb.unsqueeze(1).expand_as(emb)
            concat_emb = torch.cat([emb, lang_emb], dim=-1)
            emb = torch.relu(self.projector(concat_emb))
        else:
            raise NotImplementedError

    return emb
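
# Illustrative sketch (not part of the model): how the 'sum' and 'concat'
# language-embedding merges above behave on dummy tensors. Sizes are
# hypothetical, and `projector` is an assumed stand-in for self.projector
# with shape 2H -> H.
def _demo_language_embedding_merge():
    import torch
    import torch.nn as nn
    bsz, seq_len, hidden = 2, 5, 8
    emb = torch.randn(bsz, seq_len, hidden)
    lang_emb = torch.randn(1, hidden)                       # 1 x H, broadcasts over batch and time
    summed = emb + lang_emb                                 # 'sum': broadcast add
    projector = nn.Linear(2 * hidden, hidden)               # assumed shape of self.projector
    concat = torch.cat([emb, lang_emb.unsqueeze(1).expand_as(emb)], dim=-1)
    merged = torch.relu(projector(concat))                  # 'concat': project back to H
    assert summed.shape == merged.shape == (bsz, seq_len, hidden)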
def forward(self, input, **kwargs):
    """
    Inputs Shapes:
        input: (Variable) len_tgt x batch_size

    Outputs Shapes:
        out: len_tgt x batch_size x d_model
    """
    emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)
    emb = self.preprocess_layer(emb)

    if self.h is None:
        lstm_mem = None
    else:
        lstm_mem = (self.h.detach(), self.c.detach())

    output, (h, c) = self.rnn(emb, lstm_mem)

    output = self.postprocess_layer(output)

    output_dict = defaultdict(lambda: None)
    output_dict['hidden'] = output
    output_dict['lstm_mem'] = (h, c)

    self.h = h
    self.c = c

    return output_dict
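
# Illustrative sketch (not part of the model): detaching the carried LSTM
# state, as done above, truncates backpropagation at segment boundaries so
# gradients never flow into previous segments. Names are hypothetical.
def _demo_truncated_bptt():
    import torch
    import torch.nn as nn
    rnn = nn.LSTM(input_size=8, hidden_size=8)
    h = c = None
    for _ in range(3):                                # three consecutive segments
        x = torch.randn(5, 2, 8)                      # T x B x H
        mem = None if h is None else (h.detach(), c.detach())
        out, (h, c) = rnn(x, mem)                     # state is carried, graph is cut
        out.sum().backward()                          # backprop stays inside the segment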
def forward(self, input):
    """
    Inputs Shapes:
        input: batch_size x len_src (will be transposed)

    Outputs Shapes:
        out: batch_size x len_src x d_model
        mask_src
    """

    """ Embedding: batch_size x len_src x d_model """
    emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)

    """ Scale the emb by sqrt(d_model) """
    if self.time == 'positional_encoding':
        emb = emb * math.sqrt(self.model_size)

    """ Adding positional encoding """
    emb = self.time_transformer(emb)
    if isinstance(emb, tuple):
        emb = emb[0]
    emb = self.preprocess_layer(emb)

    mask_src = input.data.eq(onmt.constants.PAD).unsqueeze(1)  # batch_size x 1 x len_src for broadcasting
    pad_mask = torch.autograd.Variable(input.data.ne(onmt.constants.PAD))  # batch_size x len_src

    context = emb.contiguous()
    memory_bank = None

    for i, layer in enumerate(self.layer_modules):
        if len(self.layer_modules) - i <= onmt.constants.checkpointing and self.training:
            context, memory_bank = checkpoint(custom_layer(layer), context, memory_bank,
                                              mask_src, pad_mask)
        else:
            context, memory_bank = layer(context, memory_bank, mask_src,
                                         pad_mask)  # batch_size x len_src x d_model

    # From Google T2T
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    context = self.postprocess_layer(context)

    # make a huge memory bank on the encoder side
    memory_bank = torch.cat([memory_bank, context.unsqueeze(0)], dim=0)

    return memory_bank, mask_src
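
# Illustrative sketch (not part of the model): the unsqueeze(1) above turns a
# B x len_src pad mask into B x 1 x len_src so it broadcasts over the query
# dimension of B x len_q x len_src attention scores. The PAD id is hypothetical.
def _demo_pad_mask_broadcast():
    import torch
    PAD = 0
    tokens = torch.tensor([[5, 6, PAD], [7, PAD, PAD]])       # B x len_src
    mask = tokens.eq(PAD).unsqueeze(1)                        # B x 1 x len_src
    scores = torch.randn(2, 4, 3)                             # B x len_q x len_src
    scores = scores.masked_fill(mask, float('-inf'))          # mask broadcasts over len_q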
def encode(self, input, decoder_state, input_pos=None, input_lang=None):

    buffers = decoder_state.attention_buffers
    src_lang = input_lang
    input = input.transpose(0, 1)

    # Embedding stage (and scale the embedding)
    src_emb = embedded_dropout(self.src_embedding, input,
                               dropout=self.word_dropout if self.training else 0) \
        * math.sqrt(self.model_size)

    if self.use_language_embedding:
        if self.language_embedding_type in ["sum", "all_sum"]:
            src_lang_emb = self.language_embeddings(src_lang)
            src_emb += src_lang_emb

    emb = src_emb
    src_len = input.size(0)
    bsz = input.size(1)

    mask_src_src = input.eq(onmt.constants.PAD).expand(src_len, src_len, bsz)

    buffer = buffers[0] if 0 in buffers else None
    if buffer is not None:
        mem_len = buffer['k'].size(0)
    else:
        mem_len = 0

    if mem_len > 0:
        past_mask = input.new_zeros(src_len, mem_len).bool().unsqueeze(-1).expand(src_len, mem_len, bsz)
        mask_src_src = torch.cat([past_mask, mask_src_src], dim=1)

    mask_src = mask_src_src
    attn_mask = mask_src.bool()  # L x L x batch_size

    output = emb

    klen = src_len + mem_len
    pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device, dtype=emb.dtype)
    pos_emb = self.positional_encoder(pos)

    # FORWARD PASS
    coverage = None
    for i, layer in enumerate(self.layer_modules):
        # context and context_mask are None
        buffer = buffers[i] if i in buffers else None
        output, coverage, buffer = layer(output, None, pos_emb, attn_mask, None,
                                         incremental=True, incremental_cache=buffer)
        decoder_state.update_attention_buffer(buffer, i)

    # Final normalization
    output = self.postprocess_layer(output)

    return output, decoder_state
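
# Illustrative sketch (not part of the model): when a cache of length mem_len
# is prepended, the pad mask above is extended with an all-False block so
# every current position may attend to all memory slots.
def _demo_memory_mask_extension():
    import torch
    src_len, mem_len, bsz = 3, 2, 1
    cur_mask = torch.zeros(src_len, src_len, bsz, dtype=torch.bool)
    past_mask = torch.zeros(src_len, mem_len, bsz, dtype=torch.bool)  # memory is never masked
    full_mask = torch.cat([past_mask, cur_mask], dim=1)               # src_len x (mem_len + src_len) x bsz
    assert full_mask.size(1) == mem_len + src_len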
def encode(self, input, decoder_state, input_pos=None, input_lang=None):

    buffers = decoder_state.attention_buffers
    src_lang = input_lang
    input = input.transpose(0, 1)

    # Embedding stage (and scale the embedding)
    src_emb = embedded_dropout(self.src_embedding, input,
                               dropout=self.word_dropout if self.training else 0) \
        * math.sqrt(self.model_size)

    if self.use_language_embedding:
        if self.language_embedding_type in ["sum", "all_sum"]:
            src_lang_emb = self.language_embeddings(src_lang)
            src_emb += src_lang_emb

    emb = src_emb
    src_len = input.size(0)
    bsz = input.size(1)

    mask_src_src = input.eq(onmt.constants.PAD).byte()  # src_len x batch_size
    mask_src = mask_src_src.unsqueeze(0)                # 1 x src_len x batch_size

    attn_mask = mask_src.bool()  # L x L x batch_size

    output = emb

    # Applying dropout; the input is already T x B x H
    output = self.preprocess_layer(output)

    klen = src_len
    pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device, dtype=emb.dtype)
    pos_emb = self.positional_encoder(pos)

    # FORWARD PASS
    coverage = None
    for i, layer in enumerate(self.layer_modules):
        # context and context_mask are None
        buffer = buffers[i] if i in buffers else None
        output, coverage, buffer = layer(output, None, pos_emb, attn_mask, None,
                                         incremental=True, incremental_cache=buffer)
        decoder_state.update_attention_buffer(buffer, i)

    # Final normalization
    output = self.postprocess_layer(output)

    return output, decoder_state
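
# Illustrative sketch (not part of the model): the arange above enumerates
# the 2*klen - 1 possible relative offsets, from klen-1 down to -(klen-1),
# which the sinusoidal positional encoder turns into relative embeddings.
def _demo_relative_positions():
    import torch
    klen = 3
    pos = torch.arange(klen - 1, -klen, -1.0)
    # tensor([ 2.,  1.,  0., -1., -2.]) -> 2*klen - 1 = 5 relative offsets
    assert pos.numel() == 2 * klen - 1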
def process_embedding(self, input, input_lang=None):

    input_ = input
    emb = embedded_dropout(self.word_lut, input_, dropout=self.word_dropout if self.training else 0)

    if self.time == 'positional_encoding':
        emb = emb * math.sqrt(self.model_size)

    """ Adding positional encoding """
    emb = self.time_transformer(emb)

    if self.use_language_embedding:
        lang_emb = self.language_embeddings(input_lang)  # B x H or 1 x H
        if self.language_embedding_type == 'sum':
            emb = emb + lang_emb.unsqueeze(1)
        elif self.language_embedding_type == 'concat':
            lang_emb = lang_emb.unsqueeze(1).expand_as(emb)
            concat_emb = torch.cat([emb, lang_emb], dim=-1)
            emb = torch.relu(self.projector(concat_emb))
        else:
            raise NotImplementedError

    return emb
def forward(self, input, input_pos=None, input_lang=None, streaming=False, **kwargs):
    """
    Inputs Shapes:
        input: batch_size x src_len (will be transposed)

    Outputs Shapes:
        out: batch_size x src_len x d_model
        mask_src
    """

    """ Embedding: batch_size x src_len x d_model """
    if self.input_type == "text":
        bsz_first_input = input
        input = input.transpose(0, 1)

        dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1)

        if streaming:
            streaming_state = kwargs.get('streaming_state', None)
            mems = streaming_state.src_mems
            mem_len = mems[0].size(0) if mems is not None else 0
            input_length = kwargs.get('src_lengths', None)
            mask_src = self.create_stream_mask(input, input_length, mem_len)
            mask_src = mask_src.unsqueeze(2)
        else:
            mem_len = 0
            mask_src = input.eq(onmt.constants.PAD).unsqueeze(0)  # 1 x src_len x batch_size for broadcasting
            mems = None

        emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)

        """ Adding language embeddings """
        if self.use_language_embedding:
            assert self.language_embedding is not None
            # unsqueeze on dim 0 because the input is T x B x H and lang_emb is B x H
            if self.language_embedding_type in ['sum', 'all_sum']:
                lang_emb = self.language_embedding(input_lang)
                emb = emb + lang_emb.unsqueeze(0)

    else:
        if streaming:
            raise NotImplementedError

        if not self.cnn_downsampling:
            mask_src = input.narrow(2, 0, 1).squeeze(2).transpose(0, 1).eq(onmt.constants.PAD).unsqueeze(0)
            dec_attn_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1)
            input = input.narrow(2, 1, input.size(2) - 1)
            emb = self.audio_trans(input.contiguous().view(-1, input.size(2))).view(input.size(0),
                                                                                    input.size(1), -1)
            emb = emb.type_as(input)
        else:
            long_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD)
            input = input.narrow(2, 1, input.size(2) - 1)

            # first resizing to fit the CNN format
            input = input.view(input.size(0), input.size(1), -1, self.channels)
            input = input.permute(0, 3, 1, 2)
            input = self.audio_trans(input)
            input = input.permute(0, 2, 1, 3).contiguous()
            input = input.view(input.size(0), input.size(1), -1)
            input = self.linear_trans(input)

            mask_src = long_mask[:, 0:input.size(1) * 4:4].transpose(0, 1).unsqueeze(0)
            dec_attn_mask = long_mask[:, 0:input.size(1) * 4:4].unsqueeze(1)
            # the mask size here is B x T
            emb = input

        emb = emb.transpose(0, 1)
        input = input.transpose(0, 1)
        abs_pos = None
        mem_len = 0
        mems = None

    if self.unidirectional:
        qlen = input.size(0)
        klen = qlen + mem_len
        attn_mask_src = torch.triu(emb.new_ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None]

        pad_mask = mask_src

        mask_src = pad_mask + attn_mask_src
        mask_src = mask_src.gt(0)

    if onmt.constants.torch_version >= 1.2:
        mask_src = mask_src.bool()

    """ Scale the emb by sqrt(d_model) """
    emb = emb * math.sqrt(self.model_size)

    """ Adding positional encoding """
    qlen = input.size(0)
    klen = qlen + mem_len

    # Asynchronous positions: 2K+1 positions instead of K+1
    if self.unidirectional:
        pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype)
    else:
        pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device, dtype=emb.dtype)

    # pos_emb has size 2T+1 x 1 x H
    pos_emb = self.positional_encoder(pos, bsz=input.size(1) if self.fast_self_attn else None)

    if self.learnable_position_encoding:
        raise NotImplementedError

    # B x T x H -> T x B x H
    context = emb

    if streaming:
        hids = [context]

    # Apply dropout to both context and pos_emb
    context = self.preprocess_layer(context)
    pos_emb = self.preprocess_layer(pos_emb)

    if self.reversible:
        context = torch.cat([context, context], dim=-1)
        assert streaming is not True, "Streaming and Reversible is not usable yet."
        context = ReversibleEncoderFunction.apply(context, pos_emb, self.layer_modules, mask_src)
    else:
        for i, layer in enumerate(self.layer_modules):
            # src_len x batch_size x d_model
            mems_i = mems[i] if mems is not None and streaming and self.max_memory_size > 0 else None
            context = layer(context, pos_emb, mask_src, mems=mems_i)

            if streaming:
                hids.append(context)

    # final layer norm
    context = self.postprocess_layer(context)

    output_dict = defaultdict(lambda: None, {'context': context, 'src_mask': dec_attn_mask, 'src': input})

    if streaming:
        streaming_state.update_src_mems(hids, qlen)
        output_dict['streaming_state'] = streaming_state

    return output_dict
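
# Illustrative sketch (not part of the model): in the unidirectional case the
# causal triu mask (future positions) is added to the pad mask and anything
# greater than zero is masked, combining both constraints in one tensor.
def _demo_unidirectional_mask():
    import torch
    qlen, bsz = 4, 2
    causal = torch.triu(torch.ones(qlen, qlen), diagonal=1).byte()[:, :, None]  # qlen x qlen x 1
    pad = torch.zeros(1, qlen, bsz).byte()                                      # 1 x qlen x bsz
    pad[0, -1, :] = 1                                                           # last position is padding
    mask = (causal + pad).gt(0)                                                 # qlen x qlen x bsz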
def forward(self, input, context, src, input_pos=None, input_lang=None, streaming=False, **kwargs):
    """
    Inputs Shapes:
        input: (Variable) batch_size x len_tgt (will be transposed)
        context: (Variable) batch_size x src_len x d_model
        mask_src (Tensor) batch_size x src_len

    Outputs Shapes:
        out: batch_size x len_tgt x d_model
        coverage: batch_size x len_tgt x src_len
    """

    """ Embedding: batch_size x len_tgt x d_model """
    input = input.transpose(0, 1)  # T x B
    emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)
    emb = emb * math.sqrt(self.model_size)

    if streaming:
        src_lengths = kwargs.get("src_lengths", None)
        tgt_lengths = kwargs.get("tgt_lengths", None)
        streaming_state = kwargs.get("streaming_state")
        mem_len = streaming_state.prev_tgt_mem_size
        extra_context = streaming_state.extra_context
        extra_context_length = extra_context.size(0) if extra_context is not None else 0
    else:
        mem_len = 0
        mems = None
        extra_context = None

    if self.double_position:
        assert input_pos is not None
        tgt_len, bsz = input_pos.size(0), input_pos.size(1)
        input_pos_ = input_pos.view(-1).type_as(emb)
        abs_pos = self.positional_encoder(input_pos_).squeeze(1).view(tgt_len, bsz, -1)
        emb = emb + abs_pos

    if self.use_language_embedding:
        lang_emb = self.language_embeddings(input_lang)  # B x H or 1 x H
        if self.language_embedding_type == 'sum':
            emb = emb + lang_emb
        elif self.language_embedding_type == 'concat':
            # replace the bos embedding with the language embedding
            bos_emb = lang_emb.expand_as(emb[0])
            emb[0] = bos_emb

            lang_emb = lang_emb.unsqueeze(0).expand_as(emb)
            concat_emb = torch.cat([emb, lang_emb], dim=-1)
            emb = torch.relu(self.projector(concat_emb))
        else:
            raise NotImplementedError

    if context is not None:
        if self.encoder_type == "audio":
            if not self.encoder_cnn_downsampling:
                mask_src = src.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1)
            else:
                long_mask = src.data.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD)
                mask_src = long_mask[:, 0:context.size(0) * 4:4].unsqueeze(1)
        else:
            if streaming:
                context_attn_mask = self.create_context_mask(input, src, src_lengths, tgt_lengths,
                                                             extra_context_length)
                mask_src = context_attn_mask.unsqueeze(0)
            else:
                mask_src = src.eq(onmt.constants.PAD).unsqueeze(1)
    else:
        mask_src = None

    qlen = input.size(0)
    klen = qlen + mem_len

    # preparing self-attention mask; the input is either left or right aligned
    if streaming:
        dec_attn_mask = self.create_self_attn_mask(input, tgt_lengths, mem_len)
    else:
        dec_attn_mask = torch.triu(emb.new_ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None]
        pad_mask = input.eq(onmt.constants.PAD).byte()  # L x B

        dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0)
        dec_attn_mask = dec_attn_mask.gt(0)
        if onmt.constants.torch_version >= 1.2:
            dec_attn_mask = dec_attn_mask.bool()

    pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype)

    output = self.preprocess_layer(emb.contiguous())

    if streaming:
        hids = [output]
        if extra_context is not None:
            context = torch.cat([extra_context, context], dim=0)

    for i, layer in enumerate(self.layer_modules):
        # batch_size x src_len x d_model
        if streaming:
            buffer = streaming_state.tgt_buffer[i]
            output, coverage, buffer = layer(output, context, dec_attn_mask, context_attn_mask,
                                             incremental=True, incremental_cache=buffer,
                                             reuse_source=False)
            streaming_state.tgt_buffer[i] = buffer
        else:
            output, coverage, _ = layer(output, context, dec_attn_mask, mask_src)

    # From Google T2T
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    output = self.postprocess_layer(output)

    output_dict = {'hidden': output, 'coverage': coverage, 'context': context}
    output_dict = defaultdict(lambda: None, output_dict)

    if streaming:
        streaming_state.prev_tgt_mem_size += sum(tgt_lengths.tolist())
        streaming_state.prune_target_memory(self.max_memory_size)

        # if we use the extra context: keep the last context
        if self.extra_context_size > 0:
            extra_context = context[-self.extra_context_size:].detach()
            streaming_state.extra_context = extra_context

        output_dict['streaming_state'] = streaming_state

    return output_dict
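
# Illustrative sketch (not part of the model): with mem_len cached states the
# causal mask above is built over klen = qlen + mem_len keys, and
# diagonal = 1 + mem_len shifts the allowed region so each query sees the
# whole memory plus its own past.
def _demo_decoder_mask_with_memory():
    import torch
    qlen, mem_len = 3, 2
    klen = qlen + mem_len
    mask = torch.triu(torch.ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None]
    # row i can attend to the mem_len memory slots and to positions <= i
    assert mask[0, :mem_len + 1].sum() == 0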
def forward(self, input, input_pos=None, input_lang=None, streaming=False, **kwargs):
    """
    Inputs Shapes:
        input: batch_size x src_len (will be transposed)

    Outputs Shapes:
        out: batch_size x src_len x d_model
        mask_src
    """

    """ Embedding: batch_size x src_len x d_model """
    if self.input_type == "text":
        bsz_first_input = input
        input = input.transpose(0, 1)

        dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1)

        if streaming:
            raise NotImplementedError
            streaming_state = kwargs.get('streaming_state', None)
            mems = streaming_state.src_mems
            mem_len = streaming_state.prev_src_mem_size
            input_length = kwargs.get('src_lengths', None)
            mask_src = self.create_stream_mask(input, input_length, mem_len)
            mask_src = mask_src.unsqueeze(2)
        else:
            mem_len = 0
            mask_src = input.eq(onmt.constants.PAD).unsqueeze(0)  # 1 x src_len x batch_size for broadcasting
            mems = None

        emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)

        if self.double_position:
            assert input_pos is not None
            # flatten
            src_len, bsz = input_pos.size(0), input_pos.size(1)
            input_pos_ = input_pos.contiguous().view(-1).type_as(emb)
            abs_pos = self.positional_encoder(input_pos_)
            abs_pos = abs_pos.squeeze(1).view(src_len, bsz, -1)
        else:
            abs_pos = None

        """ Adding language embeddings """
        if self.use_language_embedding:
            assert self.language_embedding is not None
            # unsqueeze on dim 0 because the input is T x B x H and lang_emb is B x H
            if self.language_embedding_type in ['sum', 'all_sum']:
                lang_emb = self.language_embedding(input_lang)
                emb = emb + lang_emb.unsqueeze(0)

    else:
        if streaming:
            raise NotImplementedError

        if not self.cnn_downsampling:
            mask_src = input.narrow(2, 0, 1).squeeze(2).transpose(0, 1).eq(onmt.constants.PAD).unsqueeze(0)
            dec_attn_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1)
            input = input.narrow(2, 1, input.size(2) - 1)
            emb = self.audio_trans(input.contiguous().view(-1, input.size(2))).view(input.size(0),
                                                                                    input.size(1), -1)
        else:
            long_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD)
            input = input.narrow(2, 1, input.size(2) - 1)

            # first resizing to fit the CNN format
            input = input.view(input.size(0), input.size(1), -1, self.channels)
            input = input.permute(0, 3, 1, 2)
            input = self.audio_trans(input)
            input = input.permute(0, 2, 1, 3).contiguous()
            input = input.view(input.size(0), input.size(1), -1)
            input = self.linear_trans(input)

            mask_src = long_mask[:, 0:input.size(1) * 4:4].transpose(0, 1).unsqueeze(0)
            dec_attn_mask = long_mask[:, 0:input.size(1) * 4:4].unsqueeze(1)
            # the mask size here is B x T
            emb = input

        emb = emb.transpose(0, 1)
        input = input.transpose(0, 1)
        abs_pos = None
        mem_len = 0

    if onmt.constants.torch_version >= 1.2:
        mask_src = mask_src.bool()

    """ Scale the emb by sqrt(d_model) """
    emb = emb * math.sqrt(self.model_size)

    if self.double_position and abs_pos is not None:
        # adding position encoding
        emb = emb + abs_pos

    qlen = input.size(0)
    klen = qlen + mem_len

    # B x T x H -> T x B x H
    context = emb

    # Apply dropout to the context
    context = self.preprocess_layer(context)

    for i, layer in enumerate(self.layer_modules):
        # src_len x batch_size x d_model
        if streaming:
            buffer = streaming_state.src_buffer[i]
            context, buffer = layer(context, mask_src, incremental=True, incremental_cache=buffer)
            streaming_state.src_buffer[i] = buffer
        else:
            context = layer(context, mask_src)

    # last layer norm
    context = self.postprocess_layer(context)

    output_dict = defaultdict(lambda: None, {'context': context, 'src_mask': dec_attn_mask, 'src': input})

    if streaming:
        streaming_state.prev_src_mem_size += sum(input_length.tolist())
        streaming_state.prune_source_memory(self.max_memory_size)
        output_dict['streaming_state'] = streaming_state

    return output_dict
def forward(self, input, context, src):
    """
    Inputs Shapes:
        input: (Variable) batch_size x len_tgt (will be transposed)
        context: (Variable) batch_size x len_src x d_model
        mask_src (Tensor) batch_size x len_src

    Outputs Shapes:
        out: batch_size x len_tgt x d_model
        coverage: batch_size x len_tgt x len_src
    """

    """ Embedding: batch_size x len_tgt x d_model """
    emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)

    if self.time == 'positional_encoding':
        emb = emb * math.sqrt(self.model_size)

    """ Adding positional encoding """
    emb = self.time_transformer(emb)
    if isinstance(emb, tuple):
        emb = emb[0]
    emb = self.preprocess_layer(emb)

    mask_src = src.data.eq(onmt.constants.PAD).unsqueeze(1)
    pad_mask_src = torch.autograd.Variable(src.data.ne(onmt.constants.PAD))

    len_tgt = input.size(1)
    mask_tgt = input.data.eq(onmt.constants.PAD).unsqueeze(1) + self.mask[:len_tgt, :len_tgt]
    mask_tgt = torch.gt(mask_tgt, 0)

    output = emb.contiguous()

    pad_mask_tgt = torch.autograd.Variable(input.data.ne(onmt.constants.PAD))  # batch_size x len_tgt
    pad_mask_src = torch.autograd.Variable(1 - mask_src.squeeze(1))

    memory_bank = None

    for i, layer in enumerate(self.layer_modules):
        if len(self.layer_modules) - i <= onmt.constants.checkpointing and self.training:
            output, memory_bank, coverage = checkpoint(custom_layer(layer), output, context, memory_bank,
                                                       mask_tgt, mask_src, pad_mask_tgt, pad_mask_src)
        else:
            output, memory_bank, coverage = layer(output, context, memory_bank, mask_tgt, mask_src,
                                                  pad_mask_tgt, pad_mask_src)  # batch_size x len_tgt x d_model

    # From Google T2T
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    output = self.postprocess_layer(output)

    return output, coverage
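
# Illustrative sketch (not part of the model): the target mask above adds a
# broadcast pad mask (B x 1 x T) to an upper-triangular "future" mask (T x T)
# and thresholds with gt(0), so a position is masked if either rule fires.
def _demo_combined_target_mask():
    import torch
    PAD = 0
    tgt = torch.tensor([[3, 4, PAD]])                       # B x len_tgt
    future = torch.triu(torch.ones(3, 3, dtype=torch.uint8), diagonal=1)
    mask_tgt = tgt.eq(PAD).byte().unsqueeze(1) + future     # broadcasts to B x T x T
    mask_tgt = torch.gt(mask_tgt, 0)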
def step(self, input, decoder_state):

    src = decoder_state.src.transpose(0, 1) if decoder_state.src is not None else None
    tgt = input
    tgt_lang = decoder_state.tgt_lang
    src_lang = decoder_state.src_lang

    tgt_len = tgt.size(1)
    src_len = src.size(1)
    bsz = tgt.size(0)

    # Embedding stage (and scale the embedding)
    src_emb = embedded_dropout(self.src_embedding, src,
                               dropout=self.word_dropout if self.training else 0) \
        * math.sqrt(self.model_size)
    tgt_emb = embedded_dropout(self.tgt_embedding, tgt,
                               dropout=self.word_dropout if self.training else 0) \
        * math.sqrt(self.model_size)

    # Add position encoding
    src_emb = self.time_transformer(src_emb)
    tgt_emb = self.time_transformer(tgt_emb)

    if self.use_language_embedding:
        if self.language_embedding_type in ["sum", "all_sum"]:
            src_lang_emb = self.language_embeddings(src_lang)
            src_emb += src_lang_emb.unsqueeze(1)
            tgt_lang_emb = self.language_embeddings(tgt_lang)
            tgt_emb += tgt_lang_emb.unsqueeze(1)

    # concatenate embedding
    emb = torch.cat([src_emb, tgt_emb], dim=1)  # batch_size x L x H

    # prepare the self-attention mask over the concatenated src + tgt sequence
    attn_mask = self.gen_mask(src, input)

    output = emb

    # Applying dropout and transposing to T x B x H
    output = self.preprocess_layer(output).transpose(0, 1)

    # FORWARD PASS
    coverage = None
    for i, layer in enumerate(self.layer_modules):
        # context and context_mask are None
        output, coverage = layer(output, None, attn_mask, None)

    # Final normalization
    output = self.postprocess_layer(output)

    # only the last target position is needed for the next-token distribution
    output = output[-1:, :, :]

    output_dict = defaultdict(lambda: None)
    output_dict['hidden'] = output

    logprobs = self.generator[0](output_dict).squeeze(0)

    output_dict['src'] = decoder_state.src.transpose(0, 1)
    output_dict['log_prob'] = logprobs
    output_dict['coverage'] = logprobs.new(bsz, tgt_len, src_len).zero_()

    return output_dict
def forward(self, input, context, src, input_pos=None, src_lang=None, tgt_lang=None,
            streaming=False, **kwargs):
    """
    Inputs Shapes:
        input: (Variable) batch_size x len_tgt (will be transposed)
        context: (Variable) batch_size x src_len x d_model
        mask_src (Tensor) batch_size x src_len

    Outputs Shapes:
        out: batch_size x len_tgt x d_model
        coverage: batch_size x len_tgt x src_len
    """

    """ Embedding: batch_size x len_tgt x d_model """
    input = input.transpose(0, 1)  # T x B
    emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)
    emb = emb * math.sqrt(self.model_size)

    mem_len = 0
    mems = None
    extra_context = None

    if self.use_language_embedding:
        lang_emb = self.language_embeddings(tgt_lang)  # B x H or 1 x H
        if self.language_embedding_type == 'sum':
            emb = emb + lang_emb
        elif self.language_embedding_type == 'concat':
            lang_emb = lang_emb.unsqueeze(0).expand_as(emb)
            concat_emb = torch.cat([emb, lang_emb], dim=-1)
            emb = torch.relu(self.projector(concat_emb))
        else:
            raise NotImplementedError

    if context is not None:
        if self.encoder_type == "audio":
            if not self.encoder_cnn_downsampling:
                mask_src = src.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1)
            else:
                long_mask = src.data.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD)
                mask_src = long_mask[:, 0:context.size(0) * 4:4].unsqueeze(1)
        else:
            mask_src = src.eq(onmt.constants.PAD).unsqueeze(1)
    else:
        mask_src = None

    qlen = input.size(0)
    klen = qlen + mem_len

    # preparing self-attention mask; the input must be left-aligned
    dec_attn_mask = torch.triu(emb.new_ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None]
    dec_attn_mask = dec_attn_mask.bool()

    if not self.learnable_position_encoding:
        pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype)
        pos_emb = self.positional_encoder(pos, bsz=input.size(1))
        pos_emb = self.preprocess_layer(pos_emb)
    else:
        range_vec = torch.arange(klen, device=emb.device)
        range_mat = range_vec.unsqueeze(-1).expand(-1, klen).transpose(0, 1)
        distance_mat = range_vec - range_mat.transpose(0, 1)
        distance_mat.clamp_(-self.max_pos_length, self.max_pos_length).add_(self.max_pos_length)
        pos_emb = distance_mat

    output = self.preprocess_layer(emb.contiguous())

    lfv_vector, lid_logits = None, list()

    if self.mpw:
        src_lang = self.factor_embeddings(src_lang).squeeze(0)
        tgt_lang = self.factor_embeddings(tgt_lang).squeeze(0)
        assert src_lang.ndim == 1 and tgt_lang.ndim == 1

    for i, layer in enumerate(self.layer_modules):
        output, coverage, _ = layer(output, context, pos_emb, lfv_vector, dec_attn_mask, mask_src,
                                    src_lang=src_lang, tgt_lang=tgt_lang)

    output = self.postprocess_layer(output, factor=tgt_lang)

    output_dict = {'hidden': output, 'coverage': coverage, 'context': context, 'lid_logits': lid_logits}
    output_dict = defaultdict(lambda: None, output_dict)

    return output_dict
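
# Illustrative sketch (not part of the model): the learnable-position branch
# above builds a matrix of pairwise offsets, clamps it to
# [-max_pos_length, max_pos_length] and shifts it to non-negative indices
# suitable for an embedding lookup.
def _demo_relative_distance_indices():
    import torch
    klen, max_pos_length = 5, 2
    range_vec = torch.arange(klen)
    distance_mat = range_vec - range_vec.unsqueeze(-1)    # distance_mat[i][j] = j - i
    indices = distance_mat.clamp(-max_pos_length, max_pos_length) + max_pos_length
    assert indices.min() == 0 and indices.max() == 2 * max_pos_length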
def forward(self, input, input_pos=None, input_lang=None, streaming=False, **kwargs):
    """
    Inputs Shapes:
        input: batch_size x src_len (will be transposed)

    Outputs Shapes:
        out: batch_size x src_len x d_model
        mask_src
    """

    """ Embedding: batch_size x src_len x d_model """
    bsz_first_input = input
    input = input.transpose(0, 1)

    dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1)

    mem_len = 0
    mems = None

    emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)

    if self.early_emb_scale:
        """ Scale the emb by sqrt(d_model) """
        emb = emb * math.sqrt(self.model_size)

    """ Adding language embeddings """
    if self.use_language_embedding:
        assert self.language_embedding is not None
        # unsqueeze on dim 0 because the input is T x B x H and lang_emb is B x H
        if self.language_embedding_type in ['sum', 'all_sum']:
            lang_emb = self.language_embedding(input_lang)
            emb = emb + lang_emb.unsqueeze(0)

    """ Adding positional encoding """
    qlen = input.size(0)
    klen = qlen + mem_len

    if not self.absolute_position_encoding:
        # Asynchronous positions: 2K+1 positions instead of K+1
        if not self.learnable_position_encoding:
            pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device, dtype=emb.dtype)
            # pos_emb has size 2T+1 x 1 x H
            pos_emb = self.positional_encoder(pos, bsz=input.size(1))
            pos_emb = self.preprocess_layer(pos_emb)
        else:
            range_vec = torch.arange(klen, device=emb.device)
            range_mat = range_vec.unsqueeze(-1).expand(-1, klen).transpose(0, 1)
            distance_mat = range_vec - range_mat.transpose(0, 1)
            distance_mat.clamp_(-self.max_pos_length, self.max_pos_length).add_(self.max_pos_length)
            pos_emb = distance_mat

        mask_src = input.eq(onmt.constants.PAD).unsqueeze(0)  # 1 x src_len x batch_size for broadcasting
    else:
        # Absolute position encoding from 0 -> n
        pos, pos_emb = None, None
        emb = self.positional_encoder(emb.transpose(0, 1)).transpose(0, 1)
        mask_src = bsz_first_input.eq(onmt.constants.PAD)  # batch_size x src_len

    if onmt.constants.torch_version >= 1.2:
        mask_src = mask_src.bool()

    if not self.early_emb_scale:
        """ Scale the emb by sqrt(d_model) """
        emb = emb * math.sqrt(self.model_size)

    # context size is now T x B x H
    context = self.preprocess_layer(emb)

    if self.reversible:
        context = reversible_encoder(self.layer_modules, context, pos_emb, mask_src)
    else:
        for i, layer in enumerate(self.layer_modules):
            # src_len x batch_size x d_model
            context = layer(context, pos_emb, mask_src, src_lang=input_lang)

    # final layer norm; we can consider this layer norm a part of the output layer/function
    context = self.postprocess_layer(context)

    output_dict = defaultdict(lambda: None, {'context': context, 'src_mask': dec_attn_mask, 'src': input})

    return output_dict
def forward(self, input, context, src, input_pos=None, input_lang=None, streaming=False, **kwargs):
    """
    Inputs Shapes:
        input: (Variable) batch_size x len_tgt (will be transposed)
        context: (Variable) batch_size x src_len x d_model
        mask_src (Tensor) batch_size x src_len

    Outputs Shapes:
        out: batch_size x len_tgt x d_model
        coverage: batch_size x len_tgt x src_len
    """

    """ Embedding: batch_size x len_tgt x d_model """
    emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)

    if self.time == 'positional_encoding':
        emb = emb * math.sqrt(self.model_size)

    if self.use_language_embedding:
        lang_emb = self.language_embeddings(input_lang)  # B x H or 1 x H
        if self.language_embedding_type == 'sum':
            emb = emb + lang_emb
        elif self.language_embedding_type == 'concat':
            # replace the bos embedding with the language embedding
            bos_emb = lang_emb.expand_as(emb[:, 0, :])
            emb[:, 0, :] = bos_emb

            lang_emb = lang_emb.unsqueeze(1).expand_as(emb)
            concat_emb = torch.cat([emb, lang_emb], dim=-1)
            emb = torch.relu(self.projector(concat_emb))
        else:
            raise NotImplementedError

    if context is not None:
        if self.encoder_type == "audio":
            if not self.encoder_cnn_downsampling:
                mask_src = src.data.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1)
            else:
                long_mask = src.data.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD)
                mask_src = long_mask[:, 0:context.size(0) * 4:4].unsqueeze(1)
        else:
            mask_src = src.data.eq(onmt.constants.PAD).unsqueeze(1)
    else:
        mask_src = None

    len_tgt = input.size(1)
    mask_tgt = torch.triu(emb.new_ones(len_tgt, len_tgt), diagonal=1).byte().unsqueeze(0)
    mask_tgt = mask_tgt.bool()

    time_embedding = self.positional_encoder.get_positional_embeddings(emb)

    output = self.preprocess_layer(emb.transpose(0, 1).contiguous())

    for i in range(self.max_layers):
        layer_tensor = torch.LongTensor([i]).to(output.device)
        layer_embedding = self.layer_embeddings(layer_tensor)
        output, coverage, _ = self.universal_layer(output, time_embedding, layer_embedding,
                                                   context, mask_tgt, mask_src)

    # last layer norm
    output = self.postprocess_layer(output)

    output_dict = {'hidden': output, 'coverage': coverage, 'context': context}
    output_dict = defaultdict(lambda: None, output_dict)

    return output_dict
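
# Illustrative sketch (not part of the model): a universal transformer reuses
# ONE set of layer weights across depth, distinguishing iterations only by a
# learned layer embedding, as in the loop above. All names here are
# hypothetical stand-ins for self.universal_layer / self.layer_embeddings.
def _demo_universal_depth_loop():
    import torch
    import torch.nn as nn
    hidden, max_layers = 8, 3
    shared_layer = nn.Linear(2 * hidden, hidden)          # stand-in for the shared layer
    layer_embeddings = nn.Embedding(max_layers, hidden)   # stand-in for the layer embeddings
    x = torch.randn(5, 2, hidden)                         # T x B x H
    for i in range(max_layers):
        l_emb = layer_embeddings(torch.LongTensor([i]))   # 1 x H, broadcast over T and B
        x = torch.relu(shared_layer(torch.cat([x, l_emb.expand_as(x)], dim=-1)))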
def forward(self, input, input_pos=None, input_lang=None, streaming=False, **kwargs):
    """
    Inputs Shapes:
        input: batch_size x src_len (will be transposed)

    Outputs Shapes:
        out: batch_size x src_len x d_model
        mask_src
    """

    """ Embedding: batch_size x src_len x d_model """
    if self.input_type == "text":
        bsz_first_input = input
        input = input.transpose(0, 1)

        dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1)

        if streaming:
            raise NotImplementedError
            streaming_state = kwargs.get('streaming_state', None)
            mems = streaming_state.src_mems
            mem_len = streaming_state.prev_src_mem_size
            input_length = kwargs.get('src_lengths', None)
            mask_src = self.create_stream_mask(input, input_length, mem_len)
            mask_src = mask_src.unsqueeze(2)
        else:
            mem_len = 0
            mask_src = input.eq(onmt.constants.PAD).unsqueeze(0)  # 1 x src_len x batch_size for broadcasting
            mems = None

        emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)

        if self.double_position:
            assert input_pos is not None
            # flatten
            src_len, bsz = input_pos.size(0), input_pos.size(1)
            input_pos_ = input_pos.contiguous().view(-1).type_as(emb)
            abs_pos = self.positional_encoder(input_pos_)
            abs_pos = abs_pos.squeeze(1).view(src_len, bsz, -1)
        else:
            abs_pos = None

        """ Adding language embeddings """
        if self.use_language_embedding:
            assert self.language_embedding is not None
            # unsqueeze on dim 0 because the input is T x B x H and lang_emb is B x H
            if self.language_embedding_type in ['sum', 'all_sum']:
                lang_emb = self.language_embedding(input_lang)
                emb = emb + lang_emb.unsqueeze(0)

    else:
        if streaming:
            raise NotImplementedError

        if not self.cnn_downsampling:
            mask_src = input.narrow(2, 0, 1).squeeze(2).transpose(0, 1).eq(onmt.constants.PAD).unsqueeze(0)
            dec_attn_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1)
            input = input.narrow(2, 1, input.size(2) - 1)
            emb = self.audio_trans(input.contiguous().view(-1, input.size(2))).view(input.size(0),
                                                                                    input.size(1), -1)
        else:
            long_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD)
            input = input.narrow(2, 1, input.size(2) - 1)

            # first resizing to fit the CNN format
            input = input.view(input.size(0), input.size(1), -1, self.channels)
            input = input.permute(0, 3, 1, 2)
            input = self.audio_trans(input)
            input = input.permute(0, 2, 1, 3).contiguous()
            input = input.view(input.size(0), input.size(1), -1)
            input = self.linear_trans(input)

            mask_src = long_mask[:, 0:input.size(1) * 4:4].transpose(0, 1).unsqueeze(0)
            dec_attn_mask = long_mask[:, 0:input.size(1) * 4:4].unsqueeze(1)
            # the mask size here is B x T
            emb = input

        emb = emb.transpose(0, 1)
        input = input.transpose(0, 1)
        abs_pos = None
        mem_len = 0

    if onmt.constants.torch_version >= 1.2:
        mask_src = mask_src.bool()

    """ Scale the emb by sqrt(d_model) """
    emb = emb * math.sqrt(self.model_size)

    if self.double_position and abs_pos is not None:
        # adding position encoding
        emb = emb + abs_pos

    """ Adding positional encoding """
    qlen = input.size(0)
    klen = qlen + mem_len

    # Asynchronous positions: 2K+1 positions instead of K+1
    pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device).long()

    # unsqueeze because the batch dimension is lacking
    pos_emb = self.positional_encoder(pos).unsqueeze(1)

    # B x T x H -> T x B x H
    context = emb

    # Apply dropout to both context and pos_emb
    context = self.preprocess_layer(context)
    pos_emb = self.preprocess_layer(pos_emb)

    for i, layer in enumerate(self.layer_modules):
        # src_len x batch_size x d_model
        if streaming:
            buffer = streaming_state.src_buffer[i]
            context, buffer = layer(context, pos_emb, mask_src, incremental=True, incremental_cache=buffer)
            streaming_state.src_buffer[i] = buffer
        else:
            context = layer(context, pos_emb, mask_src)

    # last layer norm
    context = self.postprocess_layer(context)

    output_dict = defaultdict(lambda: None, {'context': context, 'src_mask': dec_attn_mask, 'src': input})

    if streaming:
        streaming_state.prev_src_mem_size += sum(input_length.tolist())
        streaming_state.prune_source_memory(self.max_memory_size)
        output_dict['streaming_state'] = streaming_state

    return output_dict
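
# Illustrative sketch (not part of the model): the audio branch above keeps
# every 4th frame of the original pad mask after the CNN stack downsamples
# the time axis by a factor of 4 (e.g. two stride-2 convolutions).
def _demo_downsampled_mask():
    import torch
    bsz, frames = 2, 16
    long_mask = torch.zeros(bsz, frames, dtype=torch.bool)   # True marks padded frames
    down_T = frames // 4
    mask = long_mask[:, 0:down_T * 4:4]                      # B x (frames // 4)
    assert mask.size(1) == down_T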
for (src_length, tgt_length) in zip(src_lengths, tgt_lengths): # # if mask is None: # prev_src_length = 0 # prev_tgt_length = 0 # else: # prev_src_length, prev_tgt_length = mask.size(1), mask.size(0) # # # current sent attend to current src sent and all src in the past # current_mask = input.new_zeros(tgt_length, src_length + prev_src_length) # # # the previous target cannot attend to the current source # if prev_tgt_length > 0: # prev_mask = input.new_ones(prev_tgt_length, src_length) # prev_mask = torch.cat([mask, prev_mask], dim=-1) # else: # prev_mask = None # # # the output mask has two parts: the prev and the current # if prev_mask is not None: # mask = torch.cat([prev_mask, current_mask], dim=0) # else: # mask = current_mask # # # elif self.stream_context == 'local_xl': # # # Local extra context: only attends to the aligned context + extra mem # # # This mode ensures that all target sentences have the same memory, not uneven like "global" # # # # for (src_length, tgt_length) in zip(src_lengths, tgt_lengths): # # # # # First: we read the existing mask to know where we are # # if mask is None: # # prev_src_length = 0 # # prev_tgt_length = 0 # # else: # # prev_src_length, prev_tgt_length = mask.size(1), mask.size(0) # # # # # current tgt sent attend to only current src sent # # if prev_src_length > 0: # # current_mask = torch.cat([input.new_ones(tgt_length, prev_src_length - extra_context_length), # # input.new_zeros(tgt_length, src_length + extra_context_length)], dim=-1) # # else: # # current_mask = input.new_zeros(tgt_length, src_length + extra_context_length) # # # # # the previous target cannot attend to the current source # # if prev_tgt_length > 0: # # prev_mask = input.new_ones(prev_tgt_length, src_length) # # prev_mask = torch.cat([mask, prev_mask], dim=-1) # # else: # # prev_mask = None # # # # # the output mask has two parts: the prev and the current # # if prev_mask is not None: # # mask = torch.cat([prev_mask, current_mask], dim=0) # # else: # # mask = current_mask # # elif self.stream_context in ['local', 'limited']: # # Local context: only attends to the aligned context # for (src_length, tgt_length) in zip(src_lengths, tgt_lengths): # # if mask is None: # prev_src_length = 0 # prev_tgt_length = 0 # else: # prev_src_length, prev_tgt_length = mask.size(1), mask.size(0) # # # current tgt sent attend to only current src sent # if prev_src_length > 0: # current_mask = torch.cat([input.new_ones(tgt_length, prev_src_length - extra_context_length), # input.new_zeros(tgt_length, src_length + extra_context_length)], dim=-1) # else: # current_mask = input.new_zeros(tgt_length, src_length + extra_context_length) # # # the previous target cannot attend to the current source # if prev_tgt_length > 0: # prev_mask = input.new_ones(prev_tgt_length, src_length) # prev_mask = torch.cat([mask, prev_mask], dim=-1) # else: # prev_mask = None # # # the output mask has two parts: the prev and the current # if prev_mask is not None: # mask = torch.cat([prev_mask, current_mask], dim=0) # else: # mask = current_mask # # mask = mask.bool() # return mask # # def create_self_attn_mask(self, input, tgt_lengths, prev_tgt_mem_size): # """ # Create a mask for the target words attending to the past # :param input: # :param tgt_lengths: # :param prev_tgt_mem_size: # :return: # """ # # if self.stream_context in ['local', 'global']: # qlen = sum(tgt_lengths.tolist()) # mlen = prev_tgt_mem_size # klen = qlen + mlen # mask = torch.triu(input.new_ones(qlen, klen), diagonal=1 + mlen).bool()[:, :, None] # elif 
self.stream_context in ['limited']: # # # past_length = prev_tgt_mem_size # mask = None # # assert prev_tgt_mem_size == 0, "This model is limited and doesn't accept memory" # # for length in tgt_lengths: # # past_length = mask.size(0) if mask is not None else 0 # # if past_length > 0: # # don't look at the past # past_mask = input.new_ones(length, past_length) # else: # past_mask = None # # # pay attention to the past words in the current sentence # current_mask = torch.triu(input.new_ones(length, length), diagonal=1) # # if past_mask is not None: # current_mask = torch.cat([past_mask, current_mask], dim=1) # # if mask is None: # mask = current_mask # else: # no_future_mask = input.new_ones(past_length, length) # mask = torch.cat([mask, no_future_mask], dim=1) # mask = torch.cat([mask, current_mask], dim=0) # # mask = mask.bool().unsqueeze(-1) # # return mask # # # TODO: merging forward_stream and forward # # TODO: write a step function for encoder # # def forward(self, input, context, src, input_pos=None, input_lang=None, streaming=False, **kwargs): # """ # Inputs Shapes: # input: (Variable) batch_size x len_tgt (wanna tranpose) # context: (Variable) batch_size x src_len x d_model # mask_src (Tensor) batch_size x src_len # Outputs Shapes: # out: batch_size x len_tgt x d_model # coverage: batch_size x len_tgt x src_len # # """ # # """ Embedding: batch_size x len_tgt x d_model """ # input = input.transpose(0, 1) # T x B # emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0) # emb = emb * math.sqrt(self.model_size) # # if streaming: # src_lengths = kwargs.get("src_lengths", None) # tgt_lengths = kwargs.get("tgt_lengths", None) # streaming_state = kwargs.get("streaming_state") # # mems = streaming_state.tgt_mems # mem_len = streaming_state.prev_tgt_mem_size # extra_context = streaming_state.extra_context # extra_context_length = extra_context.size(0) if extra_context is not None else 0 # # mem_len = mems[0].size(0) if mems is not None else 0 # else: # mem_len = 0 # mems = None # extra_context = None # # if self.double_position: # assert input_pos is not None # tgt_len, bsz = input_pos.size(0), input_pos.size(1) # input_pos_ = input_pos.view(-1).type_as(emb) # abs_pos = self.positional_encoder(input_pos_).squeeze(1).view(tgt_len, bsz, -1) # # emb = emb + abs_pos # # if self.use_language_embedding: # lang_emb = self.language_embeddings(input_lang) # B x H or 1 x H # if self.language_embedding_type == 'sum': # emb = emb + lang_emb # elif self.language_embedding_type == 'concat': # # replace the bos embedding with the language # bos_emb = lang_emb.expand_as(emb[0]) # emb[0] = bos_emb # # lang_emb = lang_emb.unsqueeze(0).expand_as(emb) # concat_emb = torch.cat([emb, lang_emb], dim=-1) # emb = torch.relu(self.projector(concat_emb)) # else: # raise NotImplementedError # # if context is not None: # if self.encoder_type == "audio": # if not self.encoder_cnn_downsampling: # mask_src = src.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1) # else: # long_mask = src.data.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD) # mask_src = long_mask[:, 0:context.size(0) * 4:4].unsqueeze(1) # else: # if streaming: # context_attn_mask = self.create_context_mask(input, src, # src_lengths, tgt_lengths, # extra_context_length) # mask_src = context_attn_mask.unsqueeze(0) # else: # mask_src = src.eq(onmt.constants.PAD).unsqueeze(1) # else: # mask_src = None # # qlen = input.size(0) # klen = qlen + mem_len # # preparing self-attention mask. 
The input is either left or right aligned # # if streaming: # dec_attn_mask = self.create_self_attn_mask(input, tgt_lengths, mem_len) # else: # dec_attn_mask = torch.triu( # emb.new_ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None] # pad_mask = input.eq(onmt.constants.PAD).byte() # L x B # # dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0) # dec_attn_mask = dec_attn_mask.gt(0) # if onmt.constants.torch_version >= 1.2: # dec_attn_mask = dec_attn_mask.bool() # # pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype) # # pos_emb = self.positional_encoder(pos) # # output = self.preprocess_layer(emb.contiguous()) # # if streaming: # hids = [output] # if extra_context is not None: # context = torch.cat([extra_context, context], dim=0) # # print(context.size(), context_attn_mask.size()) # # pos_emb = self.preprocess_layer(pos_emb) # # for i, layer in enumerate(self.layer_modules): # # batch_size x src_len x d_model output, coverage = layer(output, context, pos_emb, self.r_w_bias, # # self.r_r_bias, dec_attn_mask, mask_src) # # mems_i = mems[i] if mems is not None and streaming and # # self.stream_context in ['local', 'global'] else None # if streaming: # buffer = streaming_state.tgt_buffer[i] # output, coverage, buffer = layer(output, context, pos_emb, dec_attn_mask, context_attn_mask, # incremental=True, incremental_cache=buffer, reuse_source=False) # streaming_state.tgt_buffer[i] = buffer # else: # output, coverage, _ = layer(output, context, pos_emb, dec_attn_mask, mask_src ) # # if streaming: # # hids.append(output) # # # From Google T2T # # if normalization is done in layer_preprocess, then it should also be done # # on the output, since the output can grow very large, being the sum of # # a whole stack of unnormalized layer outputs. 
# output = self.postprocess_layer(output) # # output_dict = {'hidden': output, 'coverage': coverage, 'context': context} # output_dict = defaultdict(lambda: None, output_dict) # # if streaming: # streaming_state.prev_tgt_mem_size += sum(tgt_lengths.tolist()) # streaming_state.prune_target_memory(self.max_memory_size) # # # if we use the extra context: keep the last context # if self.extra_context_size > 0: # extra_context = context[-self.extra_context_size:].detach() # streaming_state.extra_context = extra_context # # # if self.stream_context in ['local', 'global']: # # streaming_state.update_tgt_mems(hids, qlen) # output_dict['streaming_state'] = streaming_state # # return output_dict # # def step(self, input, decoder_state, streaming=False): # """ # Inputs Shapes: # input: (Variable) batch_size x len_tgt (wanna tranpose) # context: (Variable) batch_size x src_len x d_model # mask_src (Tensor) batch_size x src_len # buffer (List of tensors) List of batch_size * len_tgt-1 * d_model for self-attention recomputing # Outputs Shapes: # out: batch_size x len_tgt x d_model # coverage: batch_size x len_tgt x src_len # # """ # # if streaming: # return self.step_streaming(input, decoder_state) # # context = decoder_state.context # buffers = decoder_state.attention_buffers # lang = decoder_state.tgt_lang # mask_src = decoder_state.src_mask # # if decoder_state.concat_input_seq: # if decoder_state.input_seq is None: # decoder_state.input_seq = input # else: # # concatenate the last input to the previous input sequence # decoder_state.input_seq = torch.cat([decoder_state.input_seq, input], 0) # input = decoder_state.input_seq.transpose(0, 1) # B x T # # src = decoder_state.src.transpose(0, 1) if decoder_state.src is not None else None # # # use the last value of input to continue decoding # if input.size(1) > 1: # input_ = input[:, -1].unsqueeze(1).transpose(0, 1) # else: # input_ = input.transpose(0, 1) # # """ Embedding: batch_size x 1 x d_model """ # emb = self.word_lut(input_) * math.sqrt(self.model_size) # input = input.transpose(0, 1) # klen = input.size(0) # # emb = self.word_lut(input) * math.sqrt(self.model_size) # # if self.double_position: # input_pos = torch.arange(input.size(0), dtype=emb.dtype, device=emb.device) # input_pos = input_pos.unsqueeze(1).repeat(1, input.size(1)) # tgt_len, bsz = input_pos.size(0), input_pos.size(1) # input_pos_ = input_pos.view(-1).type_as(emb) # abs_pos = self.positional_encoder(input_pos_).squeeze(1).view(tgt_len, bsz, -1) # emb = emb + abs_pos[-1:, :, :] # # if self.use_language_embedding: # lang_emb = self.language_embeddings(lang) # B x H # # if self.language_embedding_type in ['sum', 'all_sum']: # emb = emb + lang_emb # elif self.language_embedding_type == 'concat': # if input.size(0) == 1: # emb[0] = lang_emb # # lang_emb = lang_emb.unsqueeze(0).expand_as(emb) # concat_emb = torch.cat([emb, lang_emb], dim=-1) # emb = torch.relu(self.projector(concat_emb)) # else: # raise NotImplementedError # # # prepare position encoding # qlen = emb.size(0) # mlen = klen - qlen # # pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype) # # pos_emb = self.positional_encoder(pos) # # dec_attn_mask = torch.triu( # emb.new_ones(qlen, klen), diagonal=1 + mlen).byte()[:, :, None] # # pad_mask = input.eq(onmt.constants.PAD).byte() # L x B # # dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0) # dec_attn_mask = dec_attn_mask.gt(0) # # if onmt.constants.torch_version >= 1.2: # dec_attn_mask = dec_attn_mask.bool() # # if context is not None: # if 
self.encoder_type == "audio": # if not self.encoder_cnn_downsampling: # mask_src = src.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1) # else: # long_mask = src.data.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD) # mask_src = long_mask[:, 0:context.size(0) * 4:4].unsqueeze(1) # else: # # mask_src = src.eq(onmt.constants.PAD).unsqueeze(1) # else: # mask_src = None # # output = emb.contiguous() # # for i, layer in enumerate(self.layer_modules): # buffer = buffers[i] if i in buffers else None # # assert (output.size(0) == 1) # # # output, coverage, buffer = layer.step(output, context, pos_emb, # # dec_attn_mask, mask_src, buffer=buffer) # output, coverage, buffer = layer(output, context, pos_emb, dec_attn_mask, mask_src, # incremental=True, incremental_cache=buffer) # # decoder_state.update_attention_buffer(buffer, i) # # output = self.postprocess_layer(output) # output = output[-1].unsqueeze(0) # # output_dict = defaultdict(lambda: None) # output_dict['hidden'] = output # output_dict['coverage'] = coverage # output_dict['context'] = context # # return output_dict # # def step_streaming(self, input, decoder_state): # """Step function in streaming case""" # # context = decoder_state.context # lang = decoder_state.tgt_lang # streaming_state = decoder_state.streaming_state # # # for global model: push the context in # # if decoder_state.concat_input_seq: # if decoder_state.input_seq is None: # decoder_state.input_seq = input # else: # # concatenate the last input to the previous input sequence # decoder_state.input_seq = torch.cat([decoder_state.input_seq, input], 0) # input = decoder_state.input_seq.transpose(0, 1) # B x T # # src = decoder_state.src.transpose(0, 1) if decoder_state.src is not None else None # # # use the last value of input to continue decoding # if input.size(1) > 1: # input_ = input[:, -1].unsqueeze(1).transpose(0, 1) # else: # input_ = input.transpose(0, 1) # # emb = self.word_lut(input_) * math.sqrt(self.model_size) # input = input.transpose(0, 1) # B x T to T x B # klen = input.size(0) # # # If we start a new sentence to decode: reset the context memory # if klen == 1: # streaming_state.reset_context_memory() # if self.stream_context == 'limited': # streaming_state.reset_target_memory() # # if self.use_language_embedding: # lang_emb = self.language_embeddings(lang) # B x H or 1 x H # if self.language_embedding_type == 'sum': # emb = emb + lang_emb # elif self.language_embedding_type == 'concat': # # replace the bos embedding with the language # bos_emb = lang_emb.expand_as(emb[0]) # emb[0] = bos_emb # # lang_emb = lang_emb.unsqueeze(0).expand_as(emb) # concat_emb = torch.cat([emb, lang_emb], dim=-1) # emb = torch.relu(self.projector(concat_emb)) # else: # raise NotImplementedError # # # need to manually definte src_lengths and tgt_lengths here # src_lengths = torch.LongTensor([context.size(0)]) # tgt_lengths = torch.LongTensor([1]) # # if context is not None: # context_attn_mask = self.create_context_mask(input, src, src_lengths, tgt_lengths) # context_attn_mask = context_attn_mask.unsqueeze(0) # else: # context_attn_mask = None # # dec_attn_mask = self.create_self_attn_mask(input, tgt_lengths, streaming_state.prev_tgt_mem_size) # # dec_attn_mask = dec_attn_mask[:, -1:, :] # # klen = 1 + streaming_state.prev_tgt_mem_size # pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype) # # pos_emb = self.positional_encoder(pos) # # output = emb # # for i, layer in enumerate(self.layer_modules): # # T x B x d_model # buffer = 
streaming_state.tgt_buffer[i] # # output, coverage = layer(output, context, pos_emb, self.r_w_bias, self.r_r_bias, dec_attn_mask, mask_src) # # reuse_source = True if input.size(1) == 1 else False # reuse_source = True # # # reuse source is True in this case because we can reuse the context ... # output, coverage, buffer = layer(output, context, pos_emb, dec_attn_mask, context_attn_mask, # incremental=True, incremental_cache=buffer, reuse_source=reuse_source) # streaming_state.tgt_buffer[i] = buffer # # output = self.postprocess_layer(output) # # streaming_state.prev_tgt_mem_size += 1 # streaming_state.prune_target_memory(self.max_memory_size + input.size(0)) # # extra_context = context[-self.extra_context_size:].detach() # # output_dict = defaultdict(lambda: None, {'hidden': output, 'coverage': coverage, 'context': context}) # output_dict['streaming_state'] = streaming_state # # return output_dict
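# A minimal sketch of the incremental decoding cache that the (commented-out) step
# functions above rely on: at each decoding step every layer appends its new key/value
# projections to a per-layer buffer, so self-attention over the history does not have to
# be recomputed for the whole prefix. The {'k': ..., 'v': ...} layout and the
# update_cache helper are illustrative assumptions, not the exact incremental_cache code.
import torch


def update_cache(buffer, new_k, new_v):
    """Append this step's key/value (1 x B x H) to the running cache (T x B x H)."""
    if buffer is None:
        return {'k': new_k, 'v': new_v}
    return {'k': torch.cat([buffer['k'], new_k], dim=0),
            'v': torch.cat([buffer['v'], new_v], dim=0)}


cache = None
for _ in range(3):
    cache = update_cache(cache, torch.randn(1, 2, 8), torch.randn(1, 2, 8))
# cache['k'].size(0) == 3: one cached entry per decoded token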
def forward(self, input, **kwargs): """ Inputs Shapes: input: batch_size x len_src Outputs Shapes: out: batch_size x len_src x d_model mask_src """ # clean layer history self.history.clean() # Embedding: batch_size x len_src x d_model if self.input_type == "text": mask_src = input.data.eq(onmt.constants.PAD).unsqueeze( 1) # batch_size x len_src x 1 for broadcasting emb = embedded_dropout( self.word_lut, input, dropout=self.word_dropout if self.training else 0) else: mask_src = input.narrow(2, 0, 1).squeeze(2).eq( onmt.constants.PAD).unsqueeze(1) input = input.narrow(2, 1, input.size(2) - 1) emb = self.audio_trans(input.contiguous().view( -1, input.size(2))).view(input.size(0), input.size(1), -1) # Scale the emb by sqrt(d_model) emb = emb * math.sqrt(self.model_size) # Adding positional encoding emb = self.time_transformer(emb) # Dropout emb = self.preprocess_layer(emb) # B x T x H -> T x B x H context = emb.transpose(0, 1).contiguous() self.history.push(context) for i, layer in enumerate(self.layer_modules): context = self.history.pop() if len(self.layer_modules ) - i <= onmt.constants.checkpointing and self.training: context = checkpoint(custom_layer(layer), context, mask_src) else: context = layer(context, mask_src) # batch_size x len_src x d_model self.history.push(context) # From Google T2T # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. context = self.history.pop() context = self.postprocess_layer(context) output_dict = {'context': context, 'src_mask': mask_src} # return context, mask_src return output_dict
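# A minimal sketch of the push/pop layer history used by the forward pass above
# (self.history.clean() / push() / pop()). Assumption: the real module keeps every
# intermediate layer output and feeds the next layer a (possibly learned) combination of
# them; this stand-in simply averages whatever has been pushed so far.
import torch
import torch.nn as nn


class AveragingLayerHistory(nn.Module):
    """Hypothetical stand-in for the dense layer history used by this encoder."""

    def __init__(self):
        super().__init__()
        self.states = []

    def clean(self):
        # forget everything from the previous forward pass
        self.states = []

    def push(self, state):
        # remember this layer's output (T x B x H)
        self.states.append(state)

    def pop(self):
        # hand the next layer an average of all outputs seen so far
        return torch.stack(self.states, dim=0).mean(dim=0)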
def forward(self, input, context, src, atbs=None, **kwargs): """ Inputs Shapes: input: (Variable) batch_size x len_tgt (wanna tranpose) context: (Variable) batch_size x len_src x d_model mask_src (Tensor) batch_size x len_src Outputs Shapes: out: batch_size x len_tgt x d_model coverage: batch_size x len_tgt x len_src """ """ Embedding: batch_size x len_tgt x d_model """ self.history.clean() emb = embedded_dropout( self.word_lut, input, dropout=self.word_dropout if self.training else 0) if self.time == 'positional_encoding': emb = emb * math.sqrt(self.model_size) """ Adding positional encoding """ emb = self.time_transformer(emb) if isinstance(emb, tuple): emb = emb[0] emb = self.preprocess_layer(emb) if self.use_feature: atb_emb = self.attribute_embeddings(atbs).unsqueeze(1).repeat( 1, emb.size(1)) # B x H to 1 x B x H emb = torch.cat([emb, atb_emb], dim=-1) emb = torch.relu(self.feature_projector(emb)) if context is not None: if self.encoder_type == "audio": mask_src = src.data.narrow(2, 0, 1).squeeze(2).eq( onmt.constants.PAD).unsqueeze(1) else: mask_src = src.data.eq(onmt.constants.PAD).unsqueeze(1) else: mask_src = None if context is not None: if self.encoder_type == "audio": mask_src = src.data.narrow(2, 0, 1).squeeze(2).eq( onmt.constants.PAD).unsqueeze(1) else: mask_src = src.data.eq(onmt.constants.PAD).unsqueeze(1) else: mask_src = None len_tgt = input.size(1) mask_tgt = input.data.eq( onmt.constants.PAD).unsqueeze(1) + self.mask[:len_tgt, :len_tgt] mask_tgt = torch.gt(mask_tgt, 0) output = emb.transpose(0, 1).contiguous() self.history.push(output) for i, layer in enumerate(self.layer_modules): output = self.history.pop() if len(self.layer_modules ) - i <= onmt.constants.checkpointing and self.training: output, coverage = checkpoint(custom_layer(layer), output, context, mask_tgt, mask_src) # batch_size x len_src x d_model else: output, coverage = layer( output, context, mask_tgt, mask_src) # batch_size x len_src x d_model # write into memory self.history.push(output) # From Google T2T # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. output = self.history.pop() output = self.postprocess_layer(output) output_dict = {'hidden': output, 'coverage': coverage} # return output, None return output_dict
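# Sketch of the attribute-feature path above: one embedded attribute per sentence is
# tiled over the time axis, concatenated to the word embeddings and projected back down
# to d_model. Shapes are batch-first to match this decoder; the tiling is written with
# expand here, and the module names are placeholders rather than the real ones.
import torch
import torch.nn as nn

d_model, bsz, tgt_len, n_attributes = 8, 2, 5, 3
attribute_embeddings = nn.Embedding(n_attributes, d_model)
feature_projector = nn.Linear(2 * d_model, d_model)

emb = torch.randn(bsz, tgt_len, d_model)                                   # B x T x H
atbs = torch.tensor([0, 2])                                                # one attribute id per sentence

atb_emb = attribute_embeddings(atbs).unsqueeze(1).expand(-1, tgt_len, -1)  # B x T x H
emb = torch.relu(feature_projector(torch.cat([emb, atb_emb], dim=-1)))     # back to B x T x H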
def forward(self, input, input_pos=None, **kwargs): """ Inputs Shapes: input: batch_size x src_len (wanna tranpose) Outputs Shapes: out: batch_size x src_len x d_model mask_src """ """ Embedding: batch_size x src_len x d_model """ if self.input_type == "text": bsz_first_input = input input = input.transpose(0, 1) # mask_src = input.eq(onmt.constants.PAD).unsqueeze(1) # batch_size x src_len x 1 for broadcasting mask_src = input.eq(onmt.constants.PAD).unsqueeze(0) dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1) emb = embedded_dropout( self.word_lut, input, dropout=self.word_dropout if self.training else 0) # if self.double_position: # assert input_pos is not None # # flatten # src_len, bsz = input_pos.size(0), input_pos.size(1) # input_pos_ = input_pos.contiguous().view(-1).type_as(emb) # abs_pos = self.positional_encoder(input_pos_) # abs_pos = abs_pos.squeeze(1).view(src_len, bsz, -1) # # else: # abs_pos = None else: if not self.cnn_downsampling: mask_src = input.narrow(2, 0, 1).squeeze(2).transpose(0, 1).eq( onmt.constants.PAD).unsqueeze(0) dec_attn_mask = input.narrow(2, 0, 1).squeeze(2).eq( onmt.constants.PAD).unsqueeze(1) input = input.narrow(2, 1, input.size(2) - 1) emb = self.audio_trans(input.contiguous().view( -1, input.size(2))).view(input.size(0), input.size(1), -1) else: long_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD) input = input.narrow(2, 1, input.size(2) - 1) # first resizing to fit the CNN format input = input.view(input.size(0), input.size(1), -1, self.channels) input = input.permute(0, 3, 1, 2) input = self.audio_trans(input) input = input.permute(0, 2, 1, 3).contiguous() input = input.view(input.size(0), input.size(1), -1) # print(input.size()) input = self.linear_trans(input) mask_src = long_mask[:, 0:input.size(1) * 4:4].transpose().unsqueeze(0) dec_attn_mask = long_mask[:, 0:input.size(1) * 4:4].unsqueeze(1) # the size seems to be B x T ? emb = input emb = emb.transpose(0, 1) input = input.transpose(0, 1) abs_pos = None if onmt.constants.torch_version >= 1.2: mask_src = mask_src.bool() # Scale the emb by sqrt(d_model) emb = emb * math.sqrt(self.model_size) # if self.double_position and abs_pos is not None: # # adding position encoding # emb = emb + abs_pos klen = input.size(0) # allocate positions: from L - 1 to -L + 1 pos = torch.arange(klen - 1, -klen + 1, -1.0, device=emb.device) # clamp the positions (all postions from afar are treated equally, maybe?) pos = torch.clamp(pos, -self.max_pos_length, self.max_pos_length) # L x 1 x H pos_emb = self.positional_encoder(pos.unsqueeze(1)) # Apply dropout to both context and pos_emb context = self.preprocess_layer(emb) pos_emb = self.preprocess_layer(pos_emb) for i, layer in enumerate(self.layer_modules): # src_len x batch_size x d_model context = layer(context, pos_emb, mask_src) # From Google T2T # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. context = self.postprocess_layer(context) output_dict = { 'context': context, 'src_mask': dec_attn_mask, 'src': input } # return context, mask_src return output_dict
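# Sketch of the relative position sequence built above: offsets run from L-1 down to
# -(L-1), distant offsets are clamped to max_pos_length so they share one encoding, and
# each offset gets a sinusoidal vector. The encoder below is an assumption about what
# self.positional_encoder computes, written out for illustration.
import torch


def relative_sinusoid(klen, d_model, max_pos=512, device='cpu'):
    pos = torch.arange(klen - 1, -klen, -1.0, device=device)           # 2*klen - 1 offsets
    pos = torch.clamp(pos, -max_pos, max_pos)                          # saturate far offsets
    inv_freq = 1.0 / (10000 ** (torch.arange(0.0, d_model, 2.0, device=device) / d_model))
    sinusoid = pos.unsqueeze(1) * inv_freq.unsqueeze(0)                # (2L-1) x d_model/2
    return torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1).unsqueeze(1)  # (2L-1) x 1 x d_model


pos_emb = relative_sinusoid(klen=10, d_model=8)   # 19 x 1 x 8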
def forward(self, input, input_pos=None, input_lang=None, **kwargs): """ Inputs Shapes: input: batch_size x src_len (wanna tranpose) Outputs Shapes: out: batch_size x src_len x d_model mask_src """ """ Embedding: batch_size x src_len x d_model """ if self.input_type == "text": bsz_first_input = input input = input.transpose(0, 1) # mask_src = input.eq(onmt.constants.PAD).unsqueeze(1) # batch_size x src_len x 1 for broadcasting dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1) mem_len = 0 mask_src = input.eq(onmt.constants.PAD).unsqueeze( 0) # batch_size x src_len x 1 for broadcasting mems = None emb = embedded_dropout( self.word_lut, input, dropout=self.word_dropout if self.training else 0) """ Adding language embeddings """ if self.use_language_embedding: assert self.language_embedding is not None # There is no "unsqueeze" here because the input is T x B x H and lang_emb is B x H if self.language_embedding_type in ['sum', 'all_sum']: lang_emb = self.language_embedding(input_lang) emb = emb + lang_emb.unsqueeze(0) else: if not self.cnn_downsampling: mask_src = input.narrow(2, 0, 1).squeeze(2).transpose(0, 1).eq( onmt.constants.PAD).unsqueeze(0) dec_attn_mask = input.narrow(2, 0, 1).squeeze(2).eq( onmt.constants.PAD).unsqueeze(1) input = input.narrow(2, 1, input.size(2) - 1) emb = self.audio_trans(input.contiguous().view( -1, input.size(2))).view(input.size(0), input.size(1), -1) emb = emb.type_as(input) else: long_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD) input = input.narrow(2, 1, input.size(2) - 1) # first resizing to fit the CNN format input = input.view(input.size(0), input.size(1), -1, self.channels) input = input.permute(0, 3, 1, 2) input = self.audio_trans(input) input = input.permute(0, 2, 1, 3).contiguous() input = input.view(input.size(0), input.size(1), -1) # print(input.size()) input = self.linear_trans(input) mask_src = long_mask[:, 0:input.size(1) * 4:4].transpose().unsqueeze(0) dec_attn_mask = long_mask[:, 0:input.size(1) * 4:4].unsqueeze(1) # the size seems to be B x T ? 
emb = input emb = emb.transpose(0, 1) input = input.transpose(0, 1) abs_pos = None mem_len = 0 mems = None if self.unidirectional: qlen = input.size(0) klen = qlen + mem_len attn_mask_src = torch.triu(emb.new_ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None] # pad_mask = mask_src # mask_src = pad_mask + attn_mask_src # dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0) # mask_src = mask_src.gt(0) # with right padding, causal mask covers the mask pad mask_src = attn_mask_src if onmt.constants.torch_version >= 1.2: mask_src = mask_src.bool() """ Scale the emb by sqrt(d_model) """ emb = emb * math.sqrt(self.model_size) """ positional encoding """ qlen = input.size(0) klen = qlen + mem_len # Asynchronous positions: 2K+1 positions instead of K+1 if self.unidirectional: pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype) else: pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device, dtype=emb.dtype) # pos_emb has size 2T+1 x 1 x H pos_emb = self.positional_encoder( pos, bsz=input.size(1) if self.fast_self_attn else None) # B x T x H -> T x B x H context = emb # Apply dropout to both context and pos_emb context = self.preprocess_layer(context) pos_emb = self.preprocess_layer(pos_emb) for i, layer in enumerate(self.layer_modules): # src_len x batch_size x d_model context = layer(context, pos_emb, mask_src) # final layer norm context = self.postprocess_layer(context) output_dict = defaultdict(lambda: None, { 'context': context, 'src_mask': dec_attn_mask, 'src': input }) return output_dict
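# Sketch of what embedded_dropout (used by this encoder and throughout the file) does in
# the usual AWD-LSTM recipe: whole rows of the embedding matrix are zeroed, so every
# occurrence of a dropped word type disappears together, and the surviving rows are
# rescaled. This is an illustration of the idea, not the exact imported implementation.
import torch
import torch.nn.functional as F


def embedded_dropout_sketch(embed, words, dropout=0.1):
    if dropout == 0:
        return embed(words)
    keep = embed.weight.new_empty(embed.weight.size(0), 1).bernoulli_(1 - dropout)
    masked_weight = embed.weight * keep / (1 - dropout)       # drop and rescale whole rows
    return F.embedding(words, masked_weight, padding_idx=embed.padding_idx)


word_lut = torch.nn.Embedding(100, 16, padding_idx=0)
emb_example = embedded_dropout_sketch(word_lut, torch.tensor([[1, 2, 3]]), dropout=0.2)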
def forward(self, input, context, src, input_pos=None, src_lang=None, tgt_lang=None, streaming=False, **kwargs): """ Inputs Shapes: input: (Variable) batch_size x len_tgt (wanna tranpose) context: (Variable) batch_size x src_len x d_model mask_src (Tensor) batch_size x src_len Outputs Shapes: out: batch_size x len_tgt x d_model coverage: batch_size x len_tgt x src_len """ """ Embedding: batch_size x len_tgt x d_model """ input = input.transpose(0, 1) # T x B emb = embedded_dropout( self.word_lut, input, dropout=self.word_dropout if self.training else 0) if not self.late_emb_scale: emb = emb * math.sqrt(self.model_size) mem_len = 0 mems = None extra_context = None if self.use_language_embedding: lang_emb = self.language_embeddings(tgt_lang) # B x H or 1 x H if self.language_embedding_type == 'sum': emb = emb + lang_emb elif self.language_embedding_type == 'concat': lang_emb = lang_emb.unsqueeze(0).expand_as(emb) concat_emb = torch.cat([emb, lang_emb], dim=-1) emb = torch.relu(self.projector(concat_emb)) else: raise NotImplementedError if context is not None: mask_src = src.eq(onmt.constants.PAD).unsqueeze(1) else: mask_src = None qlen = input.size(0) klen = qlen + mem_len # preparing self-attention mask. The input is left aligned so we do not need to add the pad mask dec_attn_mask = torch.triu(emb.new_ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None] # pad_mask = input.eq(onmt.constants.PAD).byte() # L x B # # dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0) # dec_attn_mask = dec_attn_mask.gt(0) dec_attn_mask = dec_attn_mask.bool() if not self.absolute_position_encoding: # relative positions if not self.learnable_position_encoding: pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype) pos_emb = self.positional_encoder(pos, bsz=input.size(1)) pos_emb = self.preprocess_layer(pos_emb) else: range_vec = torch.arange(klen, device=emb.device) range_mat = range_vec.unsqueeze(-1).expand(-1, klen).transpose( 0, 1) distance_mat = range_vec - range_mat.transpose(0, 1) distance_mat.clamp_(-self.max_pos_length, self.max_pos_length).add_( self.max_pos_length) pos_emb = distance_mat # pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype).long() # pos.clamp_(-self.max_pos_length, self.max_pos_length).add_(self.max_pos_length) # pos_emb = pos.unsqueeze(1) else: # absolute positions emb = self.positional_encoder(emb.transpose(0, 1)).transpose(0, 1) pos, pos_emb = None, None dec_attn_mask = dec_attn_mask.squeeze(-1) if self.late_emb_scale: emb = emb * math.sqrt(self.model_size) output = self.preprocess_layer(emb.contiguous()) if self.reversible: # TODO: add src lang and tgt lang to reversible output, coverage = reversible_decoder( self.layer_modules, output, pos_emb, context, dec_attn_mask.squeeze(-1), mask_src, False, None) # incremental variables else: for i, layer in enumerate(self.layer_modules): output, coverage = layer(output, context, pos_emb, dec_attn_mask, mask_src, src_lang=src_lang, tgt_lang=tgt_lang) # if self.checkpointing == 0 or self.training is False: # # output, coverage = layer(output, context, pos_emb, dec_attn_mask, mask_src, # src_lang=src_lang, tgt_lang=tgt_lang) # # else: # output, coverage = checkpoint(create_forward_function(layer), output, context, pos_emb, # dec_attn_mask, # mask_src, src_lang, tgt_lang) # From Google T2T # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. 
        output = self.postprocess_layer(output)

        output_dict = {'hidden': output, 'coverage': coverage, 'context': context}
        output_dict = defaultdict(lambda: None, output_dict)

        return output_dict
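# Sketch of the distance matrix used above when learnable_position_encoding is set:
# entry (i, j) is the offset j - i, clipped to +/- max_pos_length and shifted into
# [0, 2 * max_pos_length], so it can index an embedding table inside the attention
# layers. How the indices are consumed downstream is an assumption for illustration.
import torch
import torch.nn as nn


def relative_index_matrix(klen, max_pos):
    r = torch.arange(klen)
    dist = r.unsqueeze(0) - r.unsqueeze(1)            # dist[i, j] = j - i
    return dist.clamp_(-max_pos, max_pos).add_(max_pos)


indices = relative_index_matrix(klen=5, max_pos=3)    # 5 x 5 indices in [0, 6]
rel_emb_table = nn.Embedding(2 * 3 + 1, 16)
rel_emb = rel_emb_table(indices)                      # 5 x 5 x 16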
def forward(self, input, input_pos=None, input_lang=None, streaming=False, **kwargs): """ Inputs Shapes: input: batch_size x src_len (wanna tranpose) Outputs Shapes: out: batch_size x src_len x d_model mask_src """ """ Embedding: batch_size x src_len x d_model """ if self.input_type == "text": mask_src = input.eq(onmt.constants.PAD).unsqueeze( 1) # batch_size x 1 x len_src for broadcasting # apply switchout # if self.switchout > 0 and self.training: # vocab_size = self.word_lut.weight.size(0) # input = switchout(input, vocab_size, self.switchout) emb = embedded_dropout( self.word_lut, input, dropout=self.word_dropout if self.training else 0) else: if not self.cnn_downsampling: mask_src = input.narrow(2, 0, 1).squeeze(2).eq( onmt.constants.PAD).unsqueeze(1) input = input.narrow(2, 1, input.size(2) - 1) emb = self.audio_trans(input.contiguous().view( -1, input.size(2))).view(input.size(0), input.size(1), -1) emb = emb.type_as(input) else: long_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD) input = input.narrow(2, 1, input.size(2) - 1) # first resizing to fit the CNN format input = input.view(input.size(0), input.size(1), -1, self.channels) input = input.permute(0, 3, 1, 2) input = self.audio_trans(input) input = input.permute(0, 2, 1, 3).contiguous() input = input.view(input.size(0), input.size(1), -1) # print(input.size()) input = self.linear_trans(input) mask_src = long_mask[:, 0:input.size(1) * 4:4].unsqueeze(1) # the size seems to be B x T ? emb = input mask_src = mask_src.bool() """ Scale the emb by sqrt(d_model) """ emb = emb * math.sqrt(self.model_size) """ Adding language embeddings """ if self.use_language_embedding: assert self.language_embedding is not None if self.language_embedding_type in ['sum', 'all_sum']: lang_emb = self.language_embedding(input_lang) emb = emb + lang_emb.unsqueeze(1) time_encoding = self.positional_encoder.get_positional_embeddings(emb) # B x T x H -> T x B x H context = self.preprocess_layer(emb.transpose(0, 1)) for i in range(self.max_layers): layer_vector = torch.LongTensor([i]).to(emb.device) layer_vector = self.layer_embedding(layer_vector).unsqueeze( 0) # 1 x 1 x model_size context = self.universal_layer(context, time_encoding, layer_vector, mask_src) # last layer norm context = self.postprocess_layer(context) output_dict = defaultdict(lambda: None, { 'context': context, 'src_mask': mask_src, 'src': input }) if streaming: streaming_state.prev_src_mem_size += sum(input_length.tolist()) streaming_state.prune_source_memory(self.max_memory_size) # streaming_state.update_src_mems(hids, qlen) output_dict['streaming_state'] = streaming_state return output_dict
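# Minimal sketch of the universal-transformer loop above: one shared layer is applied
# max_layers times, and each pass is told which depth it is at through a learned layer
# embedding. The linear layer is only a placeholder for the real self.universal_layer.
import torch
import torch.nn as nn


class ToyUniversalEncoder(nn.Module):
    def __init__(self, d_model=8, max_layers=4):
        super().__init__()
        self.max_layers = max_layers
        self.layer_embedding = nn.Embedding(max_layers, d_model)
        self.shared_layer = nn.Linear(d_model, d_model)    # stand-in for a transformer layer

    def forward(self, x):                                  # x: T x B x d_model
        for i in range(self.max_layers):
            depth = self.layer_embedding(torch.tensor([i], device=x.device))  # 1 x d_model
            x = torch.relu(self.shared_layer(x + depth))   # inject the depth signal
        return x


out = ToyUniversalEncoder()(torch.randn(5, 2, 8))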
def forward(self, input, input_lang=None, **kwargs): """ Inputs Shapes: input: batch_size x len_src (wanna tranpose) Outputs Shapes: out: batch_size x len_src x d_model mask_src """ """ Embedding: batch_size x len_src x d_model """ if self.input_type == "text": mask_src = input.eq(onmt.constants.PAD).unsqueeze( 1) # batch_size x len_src x 1 for broadcasting # apply switchout # if self.switchout > 0 and self.training: # vocab_size = self.word_lut.weight.size(0) # input = switchout(input, vocab_size, self.switchout) emb = embedded_dropout( self.word_lut, input, dropout=self.word_dropout if self.training else 0) else: if not self.cnn_downsampling: mask_src = input.narrow(2, 0, 1).squeeze(2).eq( onmt.constants.PAD).unsqueeze(1) input = input.narrow(2, 1, input.size(2) - 1) emb = self.audio_trans(input.contiguous().view( -1, input.size(2))).view(input.size(0), input.size(1), -1) else: long_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD) input = input.narrow(2, 1, input.size(2) - 1) # first resizing to fit the CNN format input = input.view(input.size(0), input.size(1), -1, self.channels) input = input.permute(0, 3, 1, 2) input = self.audio_trans(input) input = input.permute(0, 2, 1, 3).contiguous() input = input.view(input.size(0), input.size(1), -1) # print(input.size()) input = self.linear_trans(input) mask_src = long_mask[:, 0:input.size(1) * 4:4].unsqueeze(1) # the size seems to be B x T ? emb = input if torch_version >= 1.2: mask_src = mask_src.bool() """ Scale the emb by sqrt(d_model) """ emb = emb * math.sqrt(self.model_size) """ Adding positional encoding """ emb = self.time_transformer(emb) """ Adding language embeddings """ if self.use_language_embedding: assert self.language_embedding is not None if self.language_embedding_type in ['sum', 'all_sum']: lang_emb = self.language_embedding(input_lang) emb = emb + lang_emb.unsqueeze(1) # B x T x H -> T x B x H context = emb.transpose(0, 1) context = self.preprocess_layer(context) for i, layer in enumerate(self.layer_modules): context = layer(context, mask_src) # batch_size x len_src x d_model # From Google T2T # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. context = self.postprocess_layer(context) output_dict = {'context': context, 'src_mask': mask_src} # return context, mask_src return output_dict
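# Sketch of the mask subsampling used on the audio path above: after the CNN stack
# downsamples the time axis by a factor of 4, the frame-level padding mask is strided
# the same way so that mask and features stay aligned. PAD = 0 is assumed purely for
# this illustration.
import torch

PAD = 0
frames = torch.tensor([[3, 5, 7, 2, 9, 1, 4, 8, PAD, PAD, PAD, PAD]])   # B x T (frame ids)
long_mask = frames.eq(PAD)                                              # frame-level pad mask
downsampled_len = frames.size(1) // 4                                   # CNN output length
mask_src = long_mask[:, 0:downsampled_len * 4:4].unsqueeze(1)           # B x 1 x T//4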
def forward(self, batch, target_mask=None, streaming=False, **kwargs): tgt = batch.get('target_input') tgt_lang = batch.get('target_lang') if streaming: streaming_state = kwargs.get('streaming_state', None) mems = streaming_state.tgt_mems else: mems = None qlen = tgt.size(0) word_emb = embedded_dropout( self.tgt_embedding, tgt, dropout=self.word_dropout if self.training else 0) word_emb.mul_(self.model_size**0.5) if self.use_language_embedding: lang_emb = self.language_embeddings(tgt_lang) # B x H if self.language_embedding_type in ['sum', 'all_sum']: word_emb = word_emb + lang_emb else: raise NotImplementedError mlen = mems[0].size(0) if mems is not None else 0 # total length: memory + current input klen = mlen + qlen # all units having the same attention range if self.same_length: all_ones = word_emb.new_ones(qlen, klen) mask_len = klen - self.mem_len if mask_len > 0: mask_shift_len = qlen - mask_len else: mask_shift_len = qlen dec_attn_mask = ( torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1 else: dec_attn_mask = torch.triu(word_emb.new_ones(qlen, klen), diagonal=1 + mlen).byte()[:, :, None] dec_attn_mask = dec_attn_mask.bool() pos = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype) if self.clamp_len > 0: pos_seq.clamp_(max=self.clamp_len) pos_emb = self.positional_encoder(pos) # Applying dropout output = self.preprocess_layer(word_emb) if streaming: hids = [output] pos_emb = self.preprocess_layer(pos_emb) # FORWARD PASS coverage = None for i, layer in enumerate(self.layer_modules): mems_i = None if mems is None else mems[i] output, coverage = layer( output, None, pos_emb, dec_attn_mask, None, mems=mems_i) # context and context_mask are None if streaming: hids.append(output) # Final normalization output = self.postprocess_layer(output) output_dict = { 'hidden': output, 'coverage': coverage, 'context': None, 'src': None, 'target_mask': target_mask } output_dict = defaultdict(lambda: None, output_dict) # final layer: computing log probabilities logprobs = self.generator[0](output_dict) output_dict['logprobs'] = logprobs if streaming: streaming_state.update_tgt_mems(hids, qlen) output_dict['streaming_state'] = streaming_state return output_dict
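# Sketch of the "same length" attention mask built above (Transformer-XL style): besides
# the usual causal upper triangle, a lower triangle is masked as well, so every query
# attends to the same number of past positions no matter where it sits in the segment.
import torch


def same_length_mask(qlen, mlen, mem_len):
    klen = qlen + mlen
    all_ones = torch.ones(qlen, klen)
    mask_len = klen - mem_len
    mask_shift_len = qlen - mask_len if mask_len > 0 else qlen
    causal = torch.triu(all_ones, diagonal=1 + mlen)            # no peeking at the future
    shifted = torch.tril(all_ones, diagonal=-mask_shift_len)    # cap the history per query
    return (causal + shifted).bool()                            # True = masked


print(same_length_mask(qlen=4, mlen=4, mem_len=4).int())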
def forward_grow(self, input): """ Inputs Shapes: input: batch_size x len_src (wanna tranpose) Outputs Shapes: out: batch_size x len_src x d_model mask_src """ with torch.no_grad(): """ Embedding: batch_size x len_src x d_model """ emb = embedded_dropout( self.word_lut, input, dropout=self.word_dropout if self.training else 0) """ Scale the emb by sqrt(d_model) """ if self.time == 'positional_encoding': emb = emb * math.sqrt(self.model_size) """ Adding positional encoding """ emb = self.time_transformer(emb) if isinstance(emb, tuple): emb = emb[0] emb = self.preprocess_layer(emb) mask_src = input.data.eq(onmt.constants.PAD).unsqueeze( 1) # batch_size x len_src x 1 for broadcasting pad_mask = torch.autograd.Variable( input.data.ne(onmt.constants.PAD)) # batch_size x len_src #~ pad_mask = None context = emb.contiguous() memory_bank = list() for i in range(self.pretrained_point): layer = self.layer_modules[i] context, norm_input = layer( context, mask_src, pad_mask) # batch_size x len_src x d_model if i > 0: # don't keep the norm input of the first layer (a.k.a embedding) memory_bank.append(norm_input) for i in range(self.layers - self.pretrained_point): res_drop_rate = 0.0 if i == 0: res_drop_rate = self.grow_dropout layer = self.layer_modules[self.pretrained_point + i] context, norm_input = layer(context, mask_src, pad_mask, residual_dropout=res_drop_rate ) # batch_size x len_src x d_model memory_bank.append(norm_input) # From Google T2T # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. context = self.postprocess_layer(context) # make a huge memory bank on the encoder side memory_bank.append(context) memory_bank = torch.stack(memory_bank) return memory_bank, mask_src
def forward(self, input, input_lang=None, **kwargs): """ Inputs Shapes: input: batch_size x len_src (to be transposed) Outputs Shapes: out: batch_size x len_src x d_model mask_src """ """ Embedding: batch_size x len_src x d_model """ if self.input_type == "text": mask_src = input.eq(onmt.constants.PAD).unsqueeze( 1) # batch_size x 1 x len_src for broadcasting # apply switchout # if self.switchout > 0 and self.training: # vocab_size = self.word_lut.weight.size(0) # input = switchout(input, vocab_size, self.switchout) emb = embedded_dropout( self.word_lut, input, dropout=self.word_dropout if self.training else 0) else: if not self.cnn_downsampling: mask_src = input.narrow(2, 0, 1).squeeze(2).eq( onmt.constants.PAD).unsqueeze(1) input = input.narrow(2, 1, input.size(2) - 1) emb = self.audio_trans(input.contiguous().view( -1, input.size(2))).view(input.size(0), input.size(1), -1) emb = emb.type_as(input) else: long_mask = input.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD) input = input.narrow(2, 1, input.size(2) - 1) # first resizing to fit the CNN format input = input.view(input.size(0), input.size(1), -1, self.channels) input = input.permute(0, 3, 1, 2) input = self.audio_trans(input) input = input.permute(0, 2, 1, 3).contiguous() input = input.view(input.size(0), input.size(1), -1) # print(input.size()) input = self.linear_trans(input) mask_src = long_mask[:, 0:input.size(1) * 4:4].unsqueeze(1) # the size seems to be B x T ? emb = input mask_src = mask_src.bool() """ Scale the emb by sqrt(d_model) """ emb = emb * math.sqrt(self.model_size) """ Adding positional encoding """ emb = self.time_transformer(emb) """ Adding language embeddings """ if self.use_language_embedding: assert self.language_embedding is not None if self.language_embedding_type in ['sum', 'all_sum']: lang_emb = self.language_embedding(input_lang) emb = emb + lang_emb.unsqueeze(1) # B x T x H -> T x B x H context = emb.transpose(0, 1) context = self.preprocess_layer(context) if self.reversible: # x_1 and x_2 are the same at first for reversible context = torch.cat([context, context], dim=-1) context = ReversibleEncoderFunction.apply(context, self.layer_modules, mask_src) else: for i, layer in enumerate(self.layer_modules): context = layer(context, mask_src) # batch_size x len_src x d_model context = self.postprocess_layer(context) output_dict = {'context': context, 'src_mask': mask_src} # return context, mask_src return output_dict
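# Sketch of the reversible-residual idea behind ReversibleEncoderFunction above: the
# state is split into two halves (which is why the input is duplicated along the feature
# dimension), each block updates one half from the other, and the inputs can be
# recomputed from the outputs instead of being stored for backpropagation. F and G are
# placeholders for the attention and feed-forward sub-layers.
import torch
import torch.nn as nn

d = 8
F, G = nn.Linear(d, d), nn.Linear(d, d)

x = torch.randn(5, 2, 2 * d)              # T x B x 2H: the duplicated state
x1, x2 = x.chunk(2, dim=-1)

y1 = x1 + F(x2)                           # forward coupling
y2 = x2 + G(y1)

with torch.no_grad():                     # inversion: recover the inputs from the outputs
    x2_rec = y2 - G(y1)
    x1_rec = y1 - F(x2_rec)

assert torch.allclose(x1, x1_rec, atol=1e-6) and torch.allclose(x2, x2_rec, atol=1e-6)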
def forward_grow(self, input, context, src): """ Inputs Shapes: input: (Variable) batch_size x len_tgt (wanna tranpose) context: (Variable) batch_size x len_src x d_model mask_src (Tensor) batch_size x len_src Outputs Shapes: out: batch_size x len_tgt x d_model coverage: batch_size x len_tgt x len_src """ """ Embedding: batch_size x len_tgt x d_model """ with torch.no_grad(): emb = embedded_dropout( self.word_lut, input, dropout=self.word_dropout if self.training else 0) if self.time == 'positional_encoding': emb = emb * math.sqrt(self.model_size) """ Adding positional encoding """ emb = self.time_transformer(emb) if isinstance(emb, tuple): emb = emb[0] emb = self.preprocess_layer(emb) mask_src = src.data.eq(onmt.constants.PAD).unsqueeze(1) pad_mask_src = torch.autograd.Variable( src.data.ne(onmt.constants.PAD)) len_tgt = input.size(1) mask_tgt = input.data.eq(onmt.constants.PAD).unsqueeze( 1) + self.mask[:len_tgt, :len_tgt] mask_tgt = torch.gt(mask_tgt, 0) output = emb.contiguous() pad_mask_tgt = torch.autograd.Variable( input.data.ne(onmt.constants.PAD)) # batch_size x len_src pad_mask_src = torch.autograd.Variable(1 - mask_src.squeeze(1)) for i in range(self.pretrained_point): layer = self.layer_modules[i] output, coverage = layer( output, context[i], mask_tgt, mask_src, pad_mask_tgt, pad_mask_src) # batch_size x len_src x d_model for i in range(self.layers - self.pretrained_point): res_drop_rate = 0.0 if i == 0: res_drop_rate = self.grow_dropout layer = self.layer_modules[self.pretrained_point + i] output, coverage = layer(output, context[self.pretrained_point + i], mask_tgt, mask_src, pad_mask_tgt, pad_mask_src, residual_dropout=res_drop_rate ) # batch_size x len_src x d_model # From Google T2T # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. output = self.postprocess_layer(output) return output, coverage
def forward(self, dec_seq, enc_out, src, tgt_lang=None, hid=None, **kwargs): emb = embedded_dropout(self.word_lut, dec_seq, dropout=self.word_dropout if self.training else 0) emb = emb * math.sqrt(self.model_size) if self.use_language_embedding: # print("Using language embedding") lang_emb = self.language_embeddings(tgt_lang) # B x H or 1 x H if self.language_embedding_type == 'sum': dec_emb = emb + lang_emb.unsqueeze(1) elif self.language_embedding_type == 'concat': # replace the bos embedding with the language bos_emb = lang_emb.expand_as(emb[0]) emb[0] = bos_emb lang_emb = lang_emb.unsqueeze(0).expand_as(emb) concat_emb = torch.cat([emb, lang_emb], dim=-1) dec_emb = torch.relu(self.projector(concat_emb)) else: raise NotImplementedError else: dec_emb = emb if enc_out is not None: if self.encoder_type == "audio": if not self.encoder_cnn_downsampling: mask_src = src.data.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD).unsqueeze(1) else: long_mask = src.data.narrow(2, 0, 1).squeeze(2).eq(onmt.constants.PAD) mask_src = long_mask[:, 0: enc_out.size(0) * 4:4].unsqueeze(1) else: mask_src = src.data.eq(onmt.constants.PAD).unsqueeze(1) else: mask_src = None # if dec_seq.size(0) > 1 and dec_seq.size(1) > 1: # lengths = dec_seq.gt(onmt.constants.PAD).sum(-1) # dec_in = pack_padded_sequence(dec_emb, lengths, batch_first=True, enforce_sorted=False) # dec_out, hid = self.lstm(dec_in, hid) # dec_out = pad_packed_sequence(dec_out, batch_first=True)[0] # else: if self.multilingual_factorized_weights: dec_out, hid = self.lstm(dec_emb, hid, indices=tgt_lang) else: dec_out, hid = self.lstm(dec_emb, hid) lt = dec_seq.size(1) attn_mask = mask_src.expand(-1, lt, -1) if not self.fast_xattention else mask_src.squeeze(1) # dec_out = self.postprocess_layer(dec_out) dec_out = self.preprocess_attn(dec_out) dec_out = dec_out.transpose(0, 1).contiguous() enc_out = enc_out.contiguous() if self.multilingual_factorized_weights: output, coverage = self.multihead_tgt(dec_out, enc_out, enc_out, tgt_lang, tgt_lang, attn_mask) else: output, coverage = self.multihead_tgt(dec_out, enc_out, enc_out, attn_mask) output = (output + dec_out) output = self.postprocess_layer(output) output_dict = defaultdict(lambda: None, {'hidden': output, 'coverage': coverage, 'context': enc_out}) return output_dict
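# The decoder above in miniature: an LSTM produces target-side states, one cross-attention
# over the encoder output attends to the source, and the result is added back residually.
# nn.MultiheadAttention is only a stand-in for the multihead_tgt module; the padding-mask
# convention (True = padded position) matches the masks built from onmt.constants.PAD.
import torch
import torch.nn as nn

d_model, bsz, src_len, tgt_len = 16, 2, 7, 5
lstm = nn.LSTM(d_model, d_model, batch_first=True)
cross_attn = nn.MultiheadAttention(d_model, num_heads=4)

dec_emb = torch.randn(bsz, tgt_len, d_model)                   # B x T_tgt x H
enc_out = torch.randn(src_len, bsz, d_model)                   # T_src x B x H
src_pad_mask = torch.zeros(bsz, src_len, dtype=torch.bool)     # True where source is padding

dec_out, hid = lstm(dec_emb)                                   # B x T_tgt x H
query = dec_out.transpose(0, 1)                                # T_tgt x B x H
attn_out, coverage = cross_attn(query, enc_out, enc_out, key_padding_mask=src_pad_mask)
output = attn_out + query                                      # residual connection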
def forward(self, input, input_pos=None, input_lang=None, streaming=False, **kwargs): """ Inputs Shapes: input: batch_size x src_len (wanna tranpose) Outputs Shapes: out: batch_size x src_len x d_model mask_src """ """ Embedding: batch_size x src_len x d_model """ bsz_first_input = input input = input.transpose(0, 1) dec_attn_mask = bsz_first_input.eq(onmt.constants.PAD).unsqueeze(1) mem_len = 0 mask_src = input.eq(onmt.constants.PAD).unsqueeze(0) # batch_size x src_len x 1 for broadcasting mems = None emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0) """ Adding language embeddings """ if self.use_language_embedding: assert self.language_embedding is not None # There is no "unsqueeze" here because the input is T x B x H and lang_emb is B x H if self.language_embedding_type in ['sum', 'all_sum']: lang_emb = self.language_embedding(input_lang) # print(lang_emb.size(), emb.size()) emb = emb + lang_emb.unsqueeze(0) if self.unidirectional: qlen = input.size(0) klen = qlen + mem_len attn_mask_src = torch.triu( emb.new_ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None] pad_mask = mask_src mask_src = pad_mask + attn_mask_src # dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0) mask_src = mask_src.gt(0) if onmt.constants.torch_version >= 1.2: mask_src = mask_src.bool() """ Scale the emb by sqrt(d_model) """ emb = emb * math.sqrt(self.model_size) """ Adding positional encoding """ qlen = input.size(0) klen = qlen + mem_len # Asynchronous positions: 2K+1 positions instead of K+1 if self.unidirectional: pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype) else: pos = torch.arange(klen - 1, -klen, -1.0, device=emb.device, dtype=emb.dtype) # pos_emb has size 2T+1 x 1 x H pos_emb = self.positional_encoder(pos, bsz=input.size(1)) if self.learnable_position_encoding: raise NotImplementedError # B x T x H -> T x B x H context = emb # Apply dropout to both context and pos_emb context = self.preprocess_layer(context) pos_emb = self.preprocess_layer(pos_emb) for i, layer in enumerate(self.layer_modules): # src_len x batch_size x d_model context = layer(context, pos_emb, mask_src, src_lang=input_lang) # final layer norm context = self.postprocess_layer(context) output_dict = defaultdict(lambda: None, {'context': context, 'src_mask': dec_attn_mask, 'src': input}) return output_dict
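# Sketch of how the padding mask and the causal mask are combined in the unidirectional
# branch above: both are kept as 0/1 byte tensors, added, and anything greater than zero
# is masked, so a key is dropped if it is padding OR lies in the future. PAD = 0 is an
# assumption for this illustration.
import torch

PAD = 0
tokens = torch.tensor([[5, 3], [6, 4], [7, PAD], [PAD, PAD]])               # T x B
qlen = tokens.size(0)

pad_mask = tokens.eq(PAD).byte().unsqueeze(0)                               # 1 x T(keys) x B
causal = torch.triu(torch.ones(qlen, qlen), diagonal=1).byte()[:, :, None]  # T x T x 1
mask_src = (pad_mask + causal).gt(0)                                        # T x T x B, True = masked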
def forward(self, input, context, src, input_pos=None, src_lang=None, tgt_lang=None, streaming=False, **kwargs): """ Inputs Shapes: input: (Variable) batch_size x len_tgt (wanna tranpose) context: (Variable) batch_size x src_len x d_model mask_src (Tensor) batch_size x src_len Outputs Shapes: out: batch_size x len_tgt x d_model coverage: batch_size x len_tgt x src_len """ """ Embedding: batch_size x len_tgt x d_model """ input = input.transpose(0, 1) # T x B emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0) emb = emb * math.sqrt(self.model_size) mem_len = 0 mems = None extra_context = None if self.use_language_embedding: lang_emb = self.language_embeddings(tgt_lang) # B x H or 1 x H if self.language_embedding_type == 'sum': emb = emb + lang_emb elif self.language_embedding_type == 'concat': # replace the bos embedding with the language bos_emb = lang_emb.expand_as(emb[0]) emb[0] = bos_emb lang_emb = lang_emb.unsqueeze(0).expand_as(emb) concat_emb = torch.cat([emb, lang_emb], dim=-1) emb = torch.relu(self.projector(concat_emb)) else: raise NotImplementedError if context is not None: mask_src = src.eq(onmt.constants.PAD).unsqueeze(1) else: mask_src = None qlen = input.size(0) klen = qlen + mem_len # preparing self-attention mask. The input is either left or right aligned dec_attn_mask = torch.triu( emb.new_ones(qlen, klen), diagonal=1 + mem_len).byte()[:, :, None] # pad_mask = input.eq(onmt.constants.PAD).byte() # L x B # # dec_attn_mask = dec_attn_mask + pad_mask.unsqueeze(0) # dec_attn_mask = dec_attn_mask.gt(0) dec_attn_mask = dec_attn_mask.bool() pos = torch.arange(klen - 1, -1, -1.0, device=emb.device, dtype=emb.dtype) pos_emb = self.positional_encoder(pos, bsz=input.size(1)) output = self.preprocess_layer(emb.contiguous()) pos_emb = self.preprocess_layer(pos_emb) for i, layer in enumerate(self.layer_modules): output, coverage, _ = layer(output, context, pos_emb, dec_attn_mask, mask_src, src_lang=src_lang, tgt_lang=tgt_lang) # From Google T2T # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. output = self.postprocess_layer(output) output_dict = {'hidden': output, 'coverage': coverage, 'context': context} output_dict = defaultdict(lambda: None, output_dict) return output_dict
def forward(self, batch, target_mask=None, **kwargs): src = batch.get('source').transpose( 0, 1) # src_len x batch_size -> bsz x src_len tgt = batch.get('target_input').transpose( 0, 1) # len_tgt x batch_size -> bsz x tgt_len src_pos = batch.get('source_pos') tgt_pos = batch.get('target_pos') src_lang = batch.get('source_lang') tgt_lang = batch.get('target_lang') tgt_len = tgt.size(1) src_len = src.size(1) bsz = tgt.size(0) # Embedding stage (and scale the embedding) src_emb = embedded_dropout(self.src_embedding, src, dropout=self.word_dropout if self.training else 0) \ * math.sqrt(self.model_size) tgt_emb = embedded_dropout(self.tgt_embedding, tgt, dropout=self.word_dropout if self.training else 0) \ * math.sqrt(self.model_size) # Add position encoding src_emb = self.time_transformer(src_emb) tgt_emb = self.time_transformer(tgt_emb) if self.use_language_embedding: if self.language_embedding_type in ["sum", "all_sum"]: src_lang_emb = self.language_embeddings(src_lang) src_emb += src_lang_emb.unsqueeze(1) tgt_lang_emb = self.language_embeddings(tgt_lang) tgt_emb += tgt_lang_emb.unsqueeze(1) # concatenate embedding emb = torch.cat([src_emb, tgt_emb], dim=1) # L x batch_size x H # prepare self-attention mask # For the source: we have two different parts # [1 x src_len x batch_size] # mask_src_src = src.eq(onmt.constants.PAD).unsqueeze(0).byte() # src_pad_mask = mask_src_src # # Attention from src to target: everything is padded # mask_src_tgt = mask_src_src.new_ones(1, 1, 1).expand(src_len, tgt_len, bsz) # # [src_len x L x batch_size] # mask_src = torch.cat([mask_src_src.expand(src_len, src_len, bsz), mask_src_tgt], dim=1) # mask_src = mask_src.bool() # mask_src_src = src.eq(onmt.constants.PAD).unsqueeze(1).byte() # B x 1 x src_len # mask_src_tgt = mask_src_src.new_ones(bsz, src_len, tgt_len) # bsz x src_len x tgt_len # # mask_src = torch.cat([mask_src_src.expand(bsz, src_len, src_len), mask_src_tgt], dim=-1) # # # For the target: # mask_tgt_tgt = tgt.eq(onmt.constants.PAD).byte().unsqueeze(1) + self.mask[:tgt_len, :tgt_len] # mask_tgt_tgt = torch.gt(mask_tgt_tgt, 0).byte() # bsz x tgt_len x tgt_len # # mask_tgt_src = mask_tgt_tgt.new_zeros(bsz, tgt_len, src_len) + src.eq(onmt.constants.PAD).unsqueeze(1).byte() # mask_tgt = torch.cat([mask_tgt_src, mask_tgt_tgt], dim=-1) # bsz x tgt_len x T # # attn_mask = torch.cat([mask_src, mask_tgt], dim=1).bool() # L x L x batch_size # lets try to use language modeling style # input_seq = torch.cat([src, tgt], dim=-1) # seq_len = input_seq.size(1) # # attn_mask = self.mask[:seq_len, :seq_len] + input_seq.eq(onmt.constants.PAD).byte().unsqueeze(1) # attn_mask = torch.gt(attn_mask, 0).bool() attn_mask = self.gen_mask(src, tgt) output = emb # Applying dropout and tranpose to T x B x H output = self.preprocess_layer(output).transpose(0, 1) # FORWARD PASS coverage = None for i, layer in enumerate(self.layer_modules): output, coverage = layer(output, None, attn_mask, None) # context and context_mask are None # Final normalization output = self.postprocess_layer(output) # extract the "source" and "target" parts of the output context = output[:src_len, :, :] output = output[-tgt_len:, :, :] output_dict = { 'hidden': output, 'coverage': coverage, 'context': context, 'src': src, 'target_mask': target_mask } # final layer: computing log probabilities logprobs = self.generator[0](output_dict) output_dict['logprobs'] = logprobs return output_dict
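# A sketch of the joint attention mask for the concatenated [source ; target] sequence
# (an assumption, since gen_mask itself is not shown here): source positions see every
# source position, target positions see all of the source plus a causal window over the
# target, and padding is masked everywhere. PAD = 0 is assumed for illustration.
import torch

PAD = 0


def joint_mask(src, tgt):
    bsz, src_len = src.size()
    tgt_len = tgt.size(1)
    seq = torch.cat([src, tgt], dim=1)                                     # B x (S+T)
    total = src_len + tgt_len

    mask = torch.zeros(total, total, dtype=torch.bool)
    mask[:src_len, src_len:] = True                                        # source never sees the target
    mask[src_len:, src_len:] = torch.triu(torch.ones(tgt_len, tgt_len), 1).bool()  # causal over target
    pad = seq.eq(PAD).unsqueeze(1)                                         # B x 1 x (S+T), mask padded keys
    return mask.unsqueeze(0) | pad                                         # B x (S+T) x (S+T), True = masked


m = joint_mask(torch.tensor([[4, 5, PAD]]), torch.tensor([[7, 8]]))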