def forward(self, inputs, input_type='entity'):  # pylint: disable=arguments-differ
    """
    Embed a batch of entity or predicate keys, adding unseen keys on the fly.

    Parameters
    ----------
    inputs : sequence of sequences of hashable keys
        One inner sequence of keys per batch element. Shorter sequences are
        right-padded with id ``0`` up to the longest one. The argument is only
        read; it is not modified.
    input_type : ``str``, optional (default = ``'entity'``)
        Either ``'entity'`` or ``'predicate'``; selects which weight table and
        key-to-id mapping are used.

    Returns
    -------
    The result of ``self.project`` applied to the looked-up embeddings.

    Raises
    ------
    ValueError
        If ``input_type`` is neither ``'entity'`` nor ``'predicate'``.
    """
    # Pick the embedding table and vocabulary for the requested input type.
    if input_type == 'entity':
        weight = self.entity_weight
        element2id = self.entity2id
    elif input_type == 'predicate':
        weight = self.predicate_weight
        element2id = self.predicate2id
    else:
        # The message must be formatted with the offending value itself.
        raise ValueError(
                "{} is not a valid input type, use 'entity' or 'predicate'.".format(input_type))

    # Map keys to ids, registering a new embedding row for every unseen key.
    max_len = max(len(keys) for keys in inputs)
    padded_ids = []
    for keys in inputs:
        ids = []
        for key in keys:
            if key not in element2id:
                # The new key takes the next free row index of the table.
                element2id[key] = len(weight)
                # NOTE(review): presumably add_new_embedding also stores the
                # grown table on self -- confirm against its definition.
                weight = self.add_new_embedding(weight, input_type)
            ids.append(element2id[key])
        # Build a fresh row instead of overwriting the caller's list, so the
        # same batch of keys can safely be passed in again.
        padded_ids.append(ids + [0] * (max_len - len(keys)))
    id_tensor = torch.LongTensor(padded_ids)

    # Look up the embeddings of the ids.
    original_size = id_tensor.size()
    id_tensor = util.combine_initial_dims(id_tensor)
    id_tensor = util.move_to_device(id_tensor, self.cuda_device)
    embedded = embedding(id_tensor, weight,
                         max_norm=self.max_norm,
                         norm_type=self.norm_type,
                         scale_grad_by_freq=self.scale_grad_by_freq,
                         sparse=self.sparse)
    embedded = util.uncombine_initial_dims(embedded, original_size)
    return self.project(embedded, input_type)
def forward(self, inputs):  # pylint: disable=arguments-differ
    """
    Look up embeddings for ``inputs`` and optionally project them.

    ``inputs`` may carry extra leading dimensions
    ``(batch_size, d1, ..., dn, sequence_length)``. They are flattened with
    ``util.combine_initial_dims`` before the lookup and restored afterwards
    with ``util.uncombine_initial_dims``.
    """
    # Remember the incoming shape so it can be restored after the lookup.
    incoming_shape = inputs.size()
    flat_inputs = util.combine_initial_dims(inputs)

    looked_up = embedding(flat_inputs, self.weight,
                          max_norm=self.max_norm,
                          norm_type=self.norm_type,
                          scale_grad_by_freq=self.scale_grad_by_freq,
                          sparse=self.sparse)
    result = util.uncombine_initial_dims(looked_up, incoming_shape)

    # Without a projection the restored embeddings are the final answer.
    if not self._projection:
        return result

    # Wrap the projection once per dimension beyond (batch, features).
    projector = self._projection
    remaining_wraps = result.dim() - 2
    while remaining_wraps > 0:
        projector = TimeDistributed(projector)
        remaining_wraps -= 1
    return projector(result)
def forward(self,
            input_ids: torch.LongTensor,
            offsets: torch.LongTensor = None,
            token_type_ids: torch.LongTensor = None) -> torch.Tensor:
    """
    Run BERT over wordpiece ids and return one embedding per wordpiece or per token.

    Parameters
    ----------
    input_ids : ``torch.LongTensor``
        The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
    offsets : ``torch.LongTensor``, optional
        Indices of the wordpiece that should represent each original token
        (first or last wordpiece, depending on the token indexer). For
        "Definitely not" split into ["Def", "##in", "##ite", "##ly", "not"],
        the "last wordpiece" offsets are [3, 4]. When given, only the
        embeddings at those positions are returned (one per token); when
        omitted, all wordpiece embeddings are returned.
    token_type_ids : ``torch.LongTensor``, optional
        Type 0 for tokens of the first sentence and 1 for the second, as in
        the BERT paper. Defaults to all zeros.
    """
    # pylint: disable=arguments-differ
    if token_type_ids is None:
        token_type_ids = torch.zeros_like(input_ids)
    input_mask = (input_ids != 0).long()

    # BERT wants 2-d inputs, so any extra leading dimensions are folded away
    # here and restored on the way out.
    flat_ids = util.combine_initial_dims(input_ids)
    flat_types = util.combine_initial_dims(token_type_ids)
    flat_mask = util.combine_initial_dims(input_mask)
    all_encoder_layers, _ = self.bert_model(input_ids=flat_ids,
                                            token_type_ids=flat_types,
                                            attention_mask=flat_mask)

    # Either mix all layers or just keep the top one.
    # mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim).
    mix = (all_encoder_layers[-1] if self._scalar_mix is None
           else self._scalar_mix(all_encoder_layers, input_mask))

    if offsets is None:
        # Back to (batch_size, d1, ..., dn, sequence_length, embedding_dim).
        return util.uncombine_initial_dims(mix, input_ids.size())

    # offsets is (batch_size, d1, ..., dn, orig_sequence_length); flatten it
    # the same way the ids were flattened.
    offsets2d = util.combine_initial_dims(offsets)
    row_index = util.get_range_vector(offsets2d.size(0),
                                      device=util.get_device_of(mix))
    row_index = row_index.unsqueeze(1)
    # Pick, for each row, the embeddings at that row's offsets.
    picked = mix[row_index, offsets2d]
    return util.uncombine_initial_dims(picked, offsets.size())
def forward(self,
            input_ids: torch.LongTensor,
            offsets: torch.LongTensor = None,
            token_type_ids: torch.LongTensor = None) -> torch.Tensor:
    """
    Run BERT over (possibly windowed) wordpiece ids and return embeddings.

    Parameters
    ----------
    input_ids : ``torch.LongTensor``
        The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        If the last dimension is longer than ``self.max_pieces`` it is
        expected to be a concatenation of windows of that length produced by
        the wordpiece indexer, e.g.
        "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ...".
    offsets : ``torch.LongTensor``, optional
        Indices of the wordpiece that should represent each original token
        (first or last wordpiece, depending on the token indexer). For
        "Definitely not" split into ["Def", "##in", "##ite", "##ly", "not"],
        the "last wordpiece" offsets are [3, 4]. When given, only the
        embeddings at those positions are returned; otherwise all wordpiece
        embeddings are returned.
    token_type_ids : ``torch.LongTensor``, optional
        Type 0 for tokens of the first sentence and 1 for the second, as in
        the BERT paper. Defaults to all zeros. When the input is windowed
        these ids are windowed the same way as ``input_ids``.

    Returns
    -------
    The output of ``self._scalar_mix`` if one is configured; otherwise the
    last layer when ``self.combine_layers == "last"``; otherwise the stacked
    layers.
    """
    # pylint: disable=arguments-differ
    batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
    initial_dims = list(input_ids.shape[:-1])
    max_pieces = self.max_pieces

    # Over-long inputs are cut into windows of `max_pieces` that are stacked
    # along the batch dimension, so BERT sees one big batch of partial
    # sentences. Memory use can grow sharply for large batches of very long
    # sentences.
    needs_split = full_seq_len > max_pieces

    def _fold_windows(tensor):
        """Split the last axis into zero-padded windows stacked along axis 0."""
        windows = list(tensor.split(max_pieces, dim=-1))
        # All windows must have equal length, so pad the last one.
        padding_amount = max_pieces - windows[-1].size(-1)
        windows[-1] = F.pad(windows[-1], pad=[0, padding_amount], value=0)
        return torch.cat(windows, dim=0)

    if needs_split:
        input_ids = _fold_windows(input_ids)
        if token_type_ids is not None:
            # Keep caller-supplied type ids aligned with the folded input ids.
            token_type_ids = _fold_windows(token_type_ids)

    if token_type_ids is None:
        token_type_ids = torch.zeros_like(input_ids)

    input_mask = (input_ids != 0).long()

    # input_ids may have extra dimensions, so we reshape down to 2-d
    # before calling the BERT model and then reshape back at the end.
    all_encoder_layers, _ = self.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            token_type_ids=util.combine_initial_dims(token_type_ids),
            attention_mask=util.combine_initial_dims(input_mask))
    all_encoder_layers = torch.stack(all_encoder_layers)

    if needs_split:
        # First, unpack the output embeddings into one long sequence again.
        unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=1)
        unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)

        # Next, keep the positions with the most context: the middle part of
        # each window, plus the leading edge of the first window and the
        # trailing edge of the last one.
        # The stride is half the window, ignoring the special start and end
        # tokens; the offset centres it inside the window.
        stride = (max_pieces - self.start_tokens - self.end_tokens) // 2
        stride_offset = stride // 2 + self.start_tokens

        first_window = list(range(stride_offset))

        max_context_windows = [
                i for i in range(full_seq_len)
                if stride_offset - 1 < i % max_pieces < stride_offset + stride
        ]

        # Look back over what is left in the last window. If the sequence
        # fills the last window exactly, the remainder is 0 and the whole
        # window must be used; otherwise the trailing edge would be dropped.
        lookback = full_seq_len % max_pieces
        if lookback == 0:
            lookback = max_pieces

        final_window_start = full_seq_len - lookback + stride_offset + stride
        final_window = list(range(final_window_start, full_seq_len))

        select_indices = first_window + max_context_windows + final_window

        initial_dims.append(len(select_indices))

        recombined_embeddings = unpacked_embeddings[:, :, select_indices]
    else:
        recombined_embeddings = all_encoder_layers

    # recombined_embeddings is
    # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim).
    input_mask = (recombined_embeddings != 0).long()

    if offsets is None:
        # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim).
        dims = initial_dims if needs_split else input_ids.size()
        layers = util.uncombine_initial_dims(recombined_embeddings, dims)
    else:
        # offsets is (batch_size, d1, ..., dn, orig_sequence_length).
        offsets2d = util.combine_initial_dims(offsets)
        # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
        range_vector = util.get_range_vector(
                offsets2d.size(0),
                device=util.get_device_of(recombined_embeddings)).unsqueeze(1)
        # Select, in every layer, the embeddings at each row's offsets.
        selected_embeddings = recombined_embeddings[:, range_vector, offsets2d]
        layers = util.uncombine_initial_dims(selected_embeddings, offsets.size())

    if self._scalar_mix is not None:
        return self._scalar_mix(layers, input_mask)
    elif self.combine_layers == "last":
        return layers[-1]
    else:
        return layers
def forward(self,
            input_ids: torch.LongTensor,
            offsets: torch.LongTensor = None,
            token_type_ids: torch.LongTensor = None,
            history_encoding: torch.LongTensor = None,
            turn_encoding: torch.LongTensor = None,
            scenario_encoding: torch.LongTensor = None) -> torch.Tensor:
    """
    Run the BERT model with extra history/turn/scenario encodings.

    Parameters
    ----------
    input_ids : ``torch.LongTensor``
        The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        Inputs longer than ``self.max_pieces`` are expected to be a
        concatenation of padded windows of that length and are split with
        ``self.split_indices``.
    offsets : ``torch.LongTensor``, optional
        Indices of the wordpiece that should represent each original token
        (first or last wordpiece, depending on the token indexer). For
        "Definitely not" split into ["Def", "##in", "##ite", "##ly", "not"],
        the "last wordpiece" offsets are [3, 4]. When given, the embeddings at
        those positions are returned, preceded by the embedding at position 0;
        otherwise all wordpiece embeddings are returned.
    token_type_ids : ``torch.LongTensor``, optional
        Type 0 for tokens of the first sentence and 1 for the second, as in
        the BERT paper. Defaults to all zeros.
    history_encoding, turn_encoding, scenario_encoding : ``torch.LongTensor``, optional
        Additional id tensors forwarded to ``self.bert_model``; each defaults
        to all zeros.
    """
    # pylint: disable=arguments-differ
    batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
    initial_dims = list(input_ids.shape[:-1])

    # Over-long inputs are handled as windows stacked along the batch axis;
    # memory use can grow sharply for large batches of very long sentences.
    needs_split = full_seq_len > self.max_pieces

    # Treat the four auxiliary id tensors uniformly, in signature order.
    auxiliary = [token_type_ids, history_encoding, turn_encoding, scenario_encoding]
    if needs_split:
        input_ids = self.split_indices(input_ids)
        auxiliary = [None if ids is None else self.split_indices(ids)
                     for ids in auxiliary]
    # Anything not supplied becomes all zeros, shaped like the input ids.
    auxiliary = [torch.zeros_like(input_ids) if ids is None else ids
                 for ids in auxiliary]
    token_type_ids, history_encoding, turn_encoding, scenario_encoding = auxiliary

    input_mask = (input_ids != 0).long()

    # The model wants 2-d inputs; extra leading dimensions are folded away
    # here and restored at the end.
    flat_inputs = {
            "input_ids": util.combine_initial_dims(input_ids),
            "token_type_ids": util.combine_initial_dims(token_type_ids),
            "history_encoding": util.combine_initial_dims(history_encoding),
            "turn_encoding": util.combine_initial_dims(turn_encoding),
            "scenario_encoding": util.combine_initial_dims(scenario_encoding),
            "attention_mask": util.combine_initial_dims(input_mask),
    }
    all_encoder_layers, _ = self.bert_model(**flat_inputs)
    all_encoder_layers = torch.stack(all_encoder_layers)

    if not needs_split:
        recombined_embeddings = all_encoder_layers
    else:
        # Unpack the windowed outputs into one long sequence again.
        per_window = torch.split(all_encoder_layers, batch_size, dim=1)
        unpacked_embeddings = torch.cat(per_window, dim=2)

        assert batch_size == 1 and token_type_ids.max() > 0
        # Count the non-zero type ids of the first row.
        num_question_tokens = token_type_ids[0].nonzero().size(0)
        select_indices = self.indices_to_select(full_seq_len, num_question_tokens)

        initial_dims.append(len(select_indices))
        recombined_embeddings = unpacked_embeddings[:, :, select_indices]

    # recombined_embeddings is
    # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim).
    input_mask = (recombined_embeddings != 0).long()

    if self._scalar_mix is None:
        mix = recombined_embeddings[-1]
    else:
        mix = self._scalar_mix(recombined_embeddings, input_mask)
    # mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim).

    if offsets is None:
        # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim).
        dims = initial_dims if needs_split else input_ids.size()
        return util.uncombine_initial_dims(mix, dims)

    # offsets is (batch_size, d1, ..., dn, orig_sequence_length).
    offsets2d = util.combine_initial_dims(offsets)
    # Prepend position 0 to every row, giving orig_sequence_length + 1 columns.
    leading_zero = torch.zeros(offsets2d.size(0), 1,
                               dtype=offsets2d.dtype, device=offsets2d.device)
    offsets2d = torch.cat([leading_zero, offsets2d], dim=-1)

    row_index = util.get_range_vector(offsets2d.size(0),
                                      device=util.get_device_of(mix))
    row_index = row_index.unsqueeze(1)
    # Pick, for each row, the embeddings at that row's offsets.
    picked = mix[row_index, offsets2d]
    return util.uncombine_initial_dims(picked, offsets.size())
def test_uncombine_initial_dims(self):
    """Check that a flattened embedding gets its leading dimensions back."""
    leading_dims = (4, 10, 20, 17, 5)
    feature_dim = 12

    # The flat tensor has one row per element of the leading dimensions.
    num_rows = 1
    for extent in leading_dims:
        num_rows *= extent
    flat_embedding = torch.randn(num_rows, feature_dim)

    restored = util.uncombine_initial_dims(flat_embedding, torch.Size(leading_dims))

    expected_shape = list(leading_dims) + [feature_dim]
    assert list(restored.size()) == expected_shape
def forward(self,
            input_ids: torch.LongTensor,
            offsets: torch.LongTensor = None,
            token_type_ids: torch.LongTensor = None) -> torch.Tensor:
    """
    Embed wordpiece ids with BERT, optionally keeping one vector per token.

    Parameters
    ----------
    input_ids : ``torch.LongTensor``
        The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
    offsets : ``torch.LongTensor``, optional
        For each original token, the index of the wordpiece that should
        represent it (first or last wordpiece, depending on how the token
        indexer is configured). Example: "Definitely not" with wordpieces
        ["Def", "##in", "##ite", "##ly", "not"] has "last wordpiece" offsets
        [3, 4]. If provided, only the embeddings at these positions are
        returned, i.e. one per token. If not provided, the full tensor of
        wordpiece embeddings is returned.
    token_type_ids : ``torch.LongTensor``, optional
        0 for tokens of the first sentence, 1 for tokens of the second (as in
        the BERT paper). Assumed to be all 0s when not provided.
    """
    # pylint: disable=arguments-differ
    type_ids = torch.zeros_like(input_ids) if token_type_ids is None else token_type_ids
    input_mask = (input_ids != 0).long()

    # Extra leading dimensions are collapsed to 2-d for the BERT call and
    # expanded again before returning.
    model_kwargs = {
            "input_ids": util.combine_initial_dims(input_ids),
            "token_type_ids": util.combine_initial_dims(type_ids),
            "attention_mask": util.combine_initial_dims(input_mask),
    }
    all_encoder_layers, _ = self.bert_model(**model_kwargs)

    if self._scalar_mix is None:
        # No mixing configured: use the top layer only.
        mix = all_encoder_layers[-1]
    else:
        mix = self._scalar_mix(all_encoder_layers, input_mask)
    # mix: (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

    if offsets is not None:
        # offsets: (batch_size, d1, ..., dn, orig_sequence_length), flattened
        # to (batch_size * d1 * ... * dn, orig_sequence_length).
        offsets2d = util.combine_initial_dims(offsets)
        num_rows = offsets2d.size(0)
        mix_device = util.get_device_of(mix)
        range_vector = util.get_range_vector(num_rows, device=mix_device).unsqueeze(1)
        # One embedding per (row, offset) pair.
        selected_embeddings = mix[range_vector, offsets2d]
        return util.uncombine_initial_dims(selected_embeddings, offsets.size())

    # Restore (batch_size, d1, ..., dn, sequence_length, embedding_dim).
    return util.uncombine_initial_dims(mix, input_ids.size())
def forward(self,
            input_ids: torch.LongTensor,
            offsets: torch.LongTensor = None) -> torch.Tensor:
    """
    Run the BERT model over (possibly windowed) wordpiece ids.

    Parameters
    ----------
    input_ids : ``torch.LongTensor``
        The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        If the last dimension is longer than ``self.max_pieces`` it is
        expected to be a concatenation of padded windows of that length,
        e.g. "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ...".
    offsets : ``torch.LongTensor``, optional
        Indices of the wordpiece that should represent each original token
        (first or last wordpiece, depending on the token indexer). For
        "Definitely not" split into ["Def", "##in", "##ite", "##ly", "not"],
        the "last wordpiece" offsets are [3, 4]. When given, only the
        embeddings at those positions are returned (one per token); otherwise
        all wordpiece embeddings are returned.
    """
    # Per the original note: offsets come from the indexer, and gector uses
    # the "start" scheme, i.e. each offset is the index of a token's first
    # wordpiece within the sentence's wordpiece list.
    batch_size = input_ids.size(0)
    full_seq_len = input_ids.size(-1)
    initial_dims = list(input_ids.shape[:-1])

    # Over-long inputs are cut into windows of `max_pieces` and stacked along
    # the batch dimension, so the model sees one big batch of partial
    # sentences. Memory use can grow sharply for large batches of very long
    # sentences.
    window = self.max_pieces
    needs_split = full_seq_len > window
    if needs_split:
        # Cut the last axis into windows of `max_pieces`.
        pieces = list(input_ids.split(window, dim=-1))

        # Zero-pad the last window so that all windows have equal length.
        last_window_size = pieces[-1].size(-1)
        pieces[-1] = F.pad(pieces[-1], pad=[0, window - last_window_size], value=0)

        # Stack the windows along the batch dimension.
        input_ids = torch.cat(pieces, dim=0)

    # Attention pad mask: keeps attention away from the zero padding.
    input_mask = (input_ids != 0).long()

    # The model wants 2-d inputs, so extra leading dimensions are folded away.
    # Only the first element of the model output is used; the original note
    # describes it as last_hidden_state of shape
    # (batch_size, sequence_length, hidden_size).
    model_output = self.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            attention_mask=util.combine_initial_dims(input_mask),
    )
    all_encoder_layers = model_output[0]

    # Make sure the result is 4-d (layers first).
    element_rank = len(all_encoder_layers[0].shape)
    if element_rank == 3:
        all_encoder_layers = torch.stack(all_encoder_layers)
    elif element_rank == 2:
        all_encoder_layers = torch.unsqueeze(all_encoder_layers, dim=0)

    if not needs_split:
        recombined_embeddings = all_encoder_layers
    else:
        # Undo the windowing: split along the batch axis into chunks of the
        # original batch size, then join the chunks along the sequence axis.
        per_window = torch.split(all_encoder_layers, batch_size, dim=1)
        unpacked_embeddings = torch.cat(per_window, dim=2)

        # Keep the positions with the most context: the middle part of each
        # window plus the leading edge of the first window and the trailing
        # edge of the last, e.g.
        #     0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
        #   "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
        # with max_pieces = 8 should produce max context indices
        # [2, 3, 4, 10, 11, 12] with additional start and final windows with
        # indices [0, 1] and [14, 15] respectively.

        # The stride is half the window, ignoring the special start and end
        # tokens; the offset centres it inside the window.
        stride = (window - self.num_start_tokens - self.num_end_tokens) // 2
        stride_offset = stride // 2 + self.num_start_tokens

        # Leading edge of the first window.
        first_window = list(range(stride_offset))

        # The middle `stride` wordpieces of every window.
        lower, upper = stride_offset - 1, stride_offset + stride
        max_context_windows = []
        for position in range(full_seq_len):
            if lower < position % window < upper:
                max_context_windows.append(position)

        # Look back over what is left in the last window, unless the sequence
        # fills it exactly, in which case look back over the whole window.
        lookback = full_seq_len % window
        if lookback == 0:
            lookback = window

        # Trailing edge of the last window.
        final_window_start = full_seq_len - lookback + stride_offset + stride
        final_window = list(range(final_window_start, full_seq_len))

        # Head + middle indices + tail.
        select_indices = first_window + max_context_windows + final_window

        # The selected length becomes the last of the leading dimensions.
        initial_dims.append(len(select_indices))

        # Keep only the selected positions of every sentence.
        recombined_embeddings = unpacked_embeddings[:, :, select_indices]

    # recombined_embeddings is
    # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim).
    # The mask is recomputed on the embeddings themselves.
    input_mask = (recombined_embeddings != 0).long()

    if self._scalar_mix is None:
        mix = recombined_embeddings[-1]
    else:
        mix = self._scalar_mix(recombined_embeddings, input_mask)
    # mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim).

    if offsets is None:
        # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim).
        dims = initial_dims if needs_split else input_ids.size()
        return util.uncombine_initial_dims(mix, dims)

    # Per the original note: offsets are produced by the preprocess function
    # in gec_model with shape [batch_size, seq_len].
    offsets2d = util.combine_initial_dims(offsets)
    # Row indices, e.g. [0, 1, 2, 3, 4] when offsets2d.size(0) == 5.
    row_index = util.get_range_vector(offsets2d.size(0),
                                      device=util.get_device_of(mix))
    row_index = row_index.unsqueeze(1)
    # Represent each token by the wordpiece embedding at its recorded offset.
    selected_embeddings = mix[row_index, offsets2d]
    return util.uncombine_initial_dims(selected_embeddings, offsets.size())
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
    """
    Embed ``tokens`` as a softmax-weighted combination of neighbour embeddings.

    When the ``warm_mode`` or ``weighted_off`` flag is set, this is a plain
    embedding lookup. Otherwise each token is replaced by the neighbours and
    coefficients returned by ``self.hull.get_nbr_and_coeff``, and the
    neighbour embeddings are averaged with softmax weights. In adversarial
    mode the weight logits are taken from the previously registered
    ``"coeff_logit"`` hook and moved along its normalised second component,
    scaled by the received ``"step"``.
    """
    # Announce once that this embedding is active.
    if not ram_has_flag("EXE_ONCE.weighted_embedding"):
        print("The weighted embedding is working")
        import sys
        sys.stdout.flush()
        ram_set_flag("EXE_ONCE.weighted_embedding")

    def lookup(ids):
        """Plain embedding lookup with this module's settings."""
        return embedding(
            ids,
            self.weight,
            padding_idx=self.padding_index,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )

    # Weighting switched off: behave like an ordinary embedding layer.
    if ram_has_flag("warm_mode", True) or ram_has_flag("weighted_off", True):
        plain = lookup(util.combine_initial_dims(tokens))
        return util.uncombine_initial_dims(plain, tokens.size())

    nbr_tokens, _coeff = self.hull.get_nbr_and_coeff(tokens.view(-1))

    # n_words x n_nbrs x dim
    neighbour_vectors = lookup(nbr_tokens)

    if adv_utils.is_adv_mode():
        # Step from the previous logits along the normalised second
        # component of the hook (named last_bw in the original).
        last_fw, last_bw = adv_utils.read_var_hook("coeff_logit")
        bw_norm = torch.norm(last_bw, dim=-1, keepdim=True) + 1e-6
        norm_last_bw = last_bw / bw_norm
        coeff_logit = last_fw + adv_utils.recieve("step") * norm_last_bw
    else:
        # Small epsilon keeps the log finite for zero coefficients.
        coeff_logit = (_coeff + 1e-6).log()

    # Shift so the largest logit in each row is zero.
    row_max = coeff_logit.max(1, keepdim=True)[0]
    coeff_logit = coeff_logit - row_max

    # Register the logits so a later adversarial pass can read them back.
    coeff_logit.requires_grad_()
    adv_utils.register_var_hook("coeff_logit", coeff_logit)
    coeff = F.softmax(coeff_logit, dim=1)

    # Weighted sum over the neighbour axis, then restore the token shape.
    weighted = neighbour_vectors * coeff.unsqueeze(-1)
    embedded = weighted.sum(-2)
    embedded = embedded.view(*tokens.size(), self.weight.size(1))

    if adv_utils.is_adv_mode() and ram_has_flag("adjust_point"):
        # Keep the weighted value but route gradients through the raw
        # token embedding only (both terms of delta are detached).
        raw_embedded = lookup(tokens)
        delta = embedded.detach() - raw_embedded.detach()
        embedded = raw_embedded + delta
    return embedded