def mask(self, tensor):
    _, half_dim, dim = tensor.size()
    if self.onnx_trace:
        # triu and tril are not supported in onnx
        a = torch._dim_arange(tensor, 2).unsqueeze(0).repeat(half_dim, 1)
        b = torch._dim_arange(tensor, 1).unsqueeze(1).repeat(1, dim)
        mask = (a > b + half_dim).float() + (a < b).float()
        mask = torch.where(
            mask > 0,
            torch.Tensor([0]).type_as(tensor),
            torch.Tensor([float("-Inf")]).type_as(tensor),
        )
    else:
        ones = tensor.new_ones(half_dim, dim).bool()
        mask = ones.triu(half_dim + 1) + ones.tril(-1)
        mask = utils.fill_with_neg_inf(tensor.new(mask.size())).masked_fill_(mask, 0)
    return mask
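# Both branches above produce the same half_dim x dim mask: row b holds -inf on
# columns b .. b + half_dim and 0 everywhere else. A standalone sketch of that
# target using plain triu/tril (torch.full stands in for utils.fill_with_neg_inf,
# which is an assumption about that helper's behaviour):
import torch

def band_mask_sketch(half_dim, dim):
    ones = torch.ones(half_dim, dim, dtype=torch.bool)
    outside = ones.triu(half_dim + 1) | ones.tril(-1)   # True outside the band
    return torch.full((half_dim, dim), float("-inf")).masked_fill(outside, 0)

band_mask_sketch(2, 5)
# tensor([[-inf, -inf, -inf, 0., 0.],
#         [0., -inf, -inf, -inf, 0.]])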
def make_positions(tensor, padding_idx, left_pad, onnx_trace=False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """
    if onnx_trace:
        range_buf = torch._dim_arange(like=tensor, dim=1) + padding_idx + 1
        mask = tensor.ne(padding_idx)
        positions = range_buf.expand_as(tensor)
        if left_pad:
            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
        # non-padding symbols get their position, padding symbols keep padding_idx
        return positions * mask.long() + padding_idx * (1 - mask.long())

    max_pos = padding_idx + 1 + tensor.size(1)
    if not hasattr(make_positions, 'range_buf'):
        make_positions.range_buf = tensor.new()
    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
    if make_positions.range_buf.numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
    mask = tensor.ne(padding_idx)
    positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    return tensor.clone().masked_scatter_(mask, positions[mask])
def make_positions(tensor, padding_idx, left_pad, onnx_trace=False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """
    if onnx_trace:
        range_buf = torch._dim_arange(like=tensor, dim=1) + padding_idx + 1
        mask = tensor.ne(padding_idx)
        positions = range_buf.expand_as(tensor)
        if left_pad:
            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
        return positions * mask.long() + padding_idx * (1 - mask.long())

    max_pos = padding_idx + 1 + tensor.size(1)
    #if not hasattr(make_positions, 'range_buf'):
    make_positions.range_buf = tensor.new()
    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
    if make_positions.range_buf.numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
    mask = tensor.ne(padding_idx)
    positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    return tensor.clone().masked_scatter_(mask, positions[mask])
def buffered_past_mask(self, tensor):
    dim = tensor.size(0)
    if self.onnx_trace:
        a = torch._dim_arange(tensor, 0).unsqueeze(0).repeat(dim, 1)
        b = torch._dim_arange(tensor, 0).unsqueeze(1).repeat(1, dim)
        past_mask = a < b
        past_mask_neg_inf = torch.where(
            past_mask,
            torch.Tensor([float("-Inf")]),
            torch.Tensor([0]),
        ).type_as(tensor)
        return past_mask_neg_inf
    if (not hasattr(self, '_past_mask')
            or self._past_mask is None
            or self._past_mask.device != tensor.device):
        self._past_mask = torch.tril(
            utils.fill_with_neg_inf(tensor.new(dim, dim)), -1)
    if self._past_mask.size(0) < dim:
        self._past_mask = torch.tril(
            utils.fill_with_neg_inf(self._past_mask.resize_(dim, dim)), -1)
    return self._past_mask[:dim, :dim]
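# Both branches above target the same dim x dim matrix: -inf strictly below the
# diagonal (the masked "past" positions) and 0 on and above it. A standalone sketch
# of that target, outside the traced path:
import torch

def past_mask_sketch(dim):
    # keep -inf only strictly below the diagonal; tril zeroes out the rest
    return torch.tril(torch.full((dim, dim), float("-inf")), -1)

past_mask_sketch(4)
# tensor([[0., 0., 0., 0.],
#         [-inf, 0., 0., 0.],
#         [-inf, -inf, 0., 0.],
#         [-inf, -inf, -inf, 0.]])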
def _with_sentence_boundaries(self, input: torch.Tensor):
    """
    Args:
        input: the sentence Tensor
        it's bs * seq_len * num_chars in case of char input and bs * seq_len
        in case of token input

    Returns:
        tuple,
        1) processed input,
        2) tensor mask for the eos position of each sentence,
           None if did not add eos
    """
    if not self.add_bos and not self.add_eos:
        return input, None

    zero_block = input.new(0, 0)
    block_size = (input.size(0), 1, input.size(2)) if self.char_inputs else (input.size(0), 1)
    bos_block = torch.full(block_size, self.eos_idx).type_as(input) if self.add_bos else zero_block
    pad_block = torch.full(block_size, self.padding_idx).type_as(input) if self.add_eos else zero_block

    # add eos in the beginning and pad to the end of the sentence
    input = torch.cat([bos_block, input, pad_block], dim=1)

    first_pads = None  # if not add_eos, then first_pads is not valid, set to None
    if self.add_eos:
        index_block = input[:, :, 0] if self.char_inputs else input
        padding_mask = index_block.eq(self.padding_idx)
        num_pads = padding_mask.long().sum(dim=1, keepdim=True)
        max_len = input.size(1)

        # index of the first pad
        if self.onnx_trace:
            first_pads = (torch._dim_arange(input, 1).type_as(input).view(1, -1)
                          .repeat(input.size(0), 1).eq(max_len - num_pads))
            eos_indices = first_pads
            if self.char_inputs:
                eos_indices = eos_indices.unsqueeze(2).repeat(1, 1, input.size(-1))
            input = torch.where(eos_indices,
                                torch.Tensor([self.eos_idx]).type_as(input),
                                input)
        else:
            first_pads = (buffered_arange(max_len).type_as(input).view(1, -1)
                          .expand(input.size(0), -1).eq(max_len - num_pads))
            eos_indices = first_pads
            if self.char_inputs:
                eos_indices = eos_indices.unsqueeze(2).expand_as(input)
            input[eos_indices] = self.eos_idx

    return input, first_pads
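# The eos insertion above hinges on arange(max_len).eq(max_len - num_pads), which is
# True exactly at the first right-side padding slot of each row (the function
# guarantees at least one pad per row by concatenating a pad column first). A
# standalone toy check; the pad/eos indices and tokens below are made up:
import torch

pad_idx, eos_idx = 0, 2
tokens = torch.tensor([[5, 6, 7, pad_idx, pad_idx],
                       [8, 9, 4, 3, pad_idx]])
num_pads = tokens.eq(pad_idx).long().sum(dim=1, keepdim=True)
max_len = tokens.size(1)
first_pads = torch.arange(max_len).view(1, -1).expand(tokens.size(0), -1).eq(max_len - num_pads)
tokens[first_pads] = eos_idx
# tokens is now [[5, 6, 7, 2, 0],
#                [8, 9, 4, 3, 2]]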
def make_positions(X, padding_idx, left_pad, onnx_trace=False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """
    max_seq_len = X.shape[1]

    # torch._dim_arange is a temporary hack to allow tracing of arange like
    # constructs with dynamic bounds on arange. Normal arange is not traceable
    # because it does not take any tensor inputs; if the range you need is
    # based on another tensor, calling this function directly will preserve
    # tracing. Get rid of this when arange can directly take tensors for
    # bounds (so that it can be traced directly).
    if onnx_trace:
        range_buf = torch._dim_arange(like=X, dim=1) + padding_idx + 1
        mask = X.ne(padding_idx)
        positions = range_buf.expand_as(X)
        if left_pad:
            offsets = max_seq_len - mask.long().sum(dim=1).unsqueeze(1)
            positions = positions - offsets
        return positions * mask.long() + padding_idx * (1 - mask.long())

    max_pos = padding_idx + 1 + X.size(1)
    # Function attributes are used for caching
    if not hasattr(make_positions, 'range_buf'):
        make_positions.range_buf = X.new()
    make_positions.range_buf = make_positions.range_buf.type_as(X)
    if make_positions.range_buf.numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
    mask = X.ne(padding_idx)
    positions = make_positions.range_buf[:X.size(1)].expand_as(X)
    if left_pad:
        offsets = max_seq_len - mask.long().sum(dim=1).unsqueeze(1)
        positions = positions - offsets
    return X.clone().masked_scatter_(mask, positions[mask])
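# A small usage example of the eager path above (pad index 1, positions starting at
# padding_idx + 1, per the docstring); the batches below are made up:
import torch

pad = 1
X = torch.tensor([[7, 8, 9, 5],
                  [6, 4, pad, pad]])       # second row right-padded
make_positions(X, pad, left_pad=False)
# tensor([[2, 3, 4, 5],
#         [2, 3, 1, 1]])   padding keeps padding_idx, real tokens count from 2

X_left = torch.tensor([[7, 8, 9, 5],
                       [pad, pad, 6, 4]])  # second row left-padded
make_positions(X_left, pad, left_pad=True)
# tensor([[2, 3, 4, 5],
#         [1, 1, 2, 3]])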
def forward(self, input):
    return torch._dim_arange(input, 1)
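# Eagerly, torch._dim_arange(input, 1) is just an arange over input.size(1); the point
# of the private helper is that the bound stays tied to the input during ONNX tracing.
# A minimal check, with DimArange as a made-up name for a module like the one above
# (newer PyTorch releases may no longer expose torch._dim_arange):
import torch
import torch.nn as nn

class DimArange(nn.Module):
    def forward(self, input):
        return torch._dim_arange(input, 1)

x = torch.zeros(2, 5)
assert DimArange()(x).tolist() == list(range(5))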
def forward(self, input_token, target_token, timestep, *inputs):
    """
    Decoder step inputs correspond one-to-one to encoder outputs.
    """
    log_probs_per_model = []
    state_outputs = []

    next_state_input = len(self.models)

    # underlying assumption is each model has same vocab_reduction_module
    vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
    if vocab_reduction_module is not None:
        possible_translation_tokens = inputs[len(self.models)]
        next_state_input += 1
    else:
        possible_translation_tokens = None

    for i, model in enumerate(self.models):
        encoder_output = inputs[i]
        prev_hiddens = []
        prev_cells = []

        for _ in range(len(model.decoder.layers)):
            prev_hiddens.append(inputs[next_state_input])
            prev_cells.append(inputs[next_state_input + 1])
            next_state_input += 2
        prev_input_feed = inputs[next_state_input].view(1, -1)
        next_state_input += 1

        # no batching, we only care about the "max" length
        src_length_int = int(encoder_output.size()[0])
        src_length = torch.LongTensor(np.array([src_length_int]))

        # notional, not actually used for decoder computation
        src_tokens = torch.LongTensor(np.array([[0] * src_length_int]))
        src_embeddings = encoder_output.new_zeros(encoder_output.shape)

        encoder_out = (
            encoder_output,
            prev_hiddens,
            prev_cells,
            src_length,
            src_tokens,
            src_embeddings,
        )

        # store cached states, use evaluation mode
        model.decoder._is_incremental_eval = True
        model.eval()

        # placeholder
        incremental_state = {}

        # cache previous state inputs
        utils.set_incremental_state(
            model.decoder,
            incremental_state,
            "cached_state",
            (prev_hiddens, prev_cells, prev_input_feed),
        )

        decoder_output = model.decoder(
            input_token.view(1, 1),
            encoder_out,
            incremental_state=incremental_state,
            possible_translation_tokens=possible_translation_tokens,
        )
        logits, _, _ = decoder_output

        log_probs = F.log_softmax(logits, dim=2)
        log_probs_per_model.append(log_probs)

        (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
            model.decoder, incremental_state, "cached_state")

        for h, c in zip(next_hiddens, next_cells):
            state_outputs.extend([h, c])
        state_outputs.append(next_input_feed)

    average_log_probs = torch.mean(
        torch.cat(log_probs_per_model, dim=0), dim=0, keepdim=True)

    if possible_translation_tokens is not None:
        reduced_indices = torch.zeros(self.vocab_size).long().fill_(self.unk_token)
        # ONNX-exportable arange (ATen op)
        possible_translation_token_range = torch._dim_arange(
            like=possible_translation_tokens, dim=0)
        reduced_indices[possible_translation_tokens] = possible_translation_token_range
        reduced_index = reduced_indices.index_select(dim=0, index=target_token)
        score = average_log_probs.view((-1,)).index_select(dim=0, index=reduced_index)
    else:
        score = average_log_probs.view((-1,)).index_select(dim=0, index=target_token)

    word_reward = self.word_rewards.index_select(0, target_token)
    score += word_reward

    self.input_names = ["prev_token", "target_token", "timestep"]
    for i in range(len(self.models)):
        self.input_names.append(f"fixed_input_{i}")

    if possible_translation_tokens is not None:
        self.input_names.append("possible_translation_tokens")

    outputs = [score]
    self.output_names = ["score"]

    for i in range(len(self.models)):
        self.output_names.append(f"fixed_input_{i}")
        outputs.append(inputs[i])

    if possible_translation_tokens is not None:
        self.output_names.append("possible_translation_tokens")
        outputs.append(possible_translation_tokens)

    for i, state in enumerate(state_outputs):
        outputs.append(state)
        self.output_names.append(f"state_output_{i}")
        self.input_names.append(f"state_input_{i}")

    return tuple(outputs)
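# The vocab-reduction remapping in the step above can be checked in isolation:
# full-vocabulary target ids are mapped to their position in the reduced output,
# falling back to unk when the token is not among the candidates. Toy sizes and ids
# below are made up, and a plain arange stands in for the traced _dim_arange:
import torch

vocab_size, unk = 10, 3
possible_translation_tokens = torch.tensor([2, 5, 7, 9])  # reduced output order

reduced_indices = torch.zeros(vocab_size).long().fill_(unk)
reduced_indices[possible_translation_tokens] = torch.arange(
    possible_translation_tokens.size(0))

target_token = torch.tensor([7, 4])
reduced_index = reduced_indices.index_select(dim=0, index=target_token)
# tensor([2, 3]): id 7 is the third reduced token; id 4 falls back to unk (= 3)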
def forward(self, x):
    return torch._dim_arange(x, 1)
def reverse_sort(x: torch.Tensor, dim: int) -> torch.Tensor:
    # `x` holds a permutation of 0..size(dim)-1 along `dim` (e.g. the index tensor
    # returned by torch.sort); the result is the inverse permutation along `dim`.
    new_indices = torch.empty_like(x)
    new_size = [1] * dim + [x.size(dim)] + [1] * (x.ndimension() - dim - 1)
    # arange laid out along `dim` and broadcast to x's shape
    arange = torch._dim_arange(x, dim=dim).reshape(new_size).expand_as(x)
    # new_indices[x[i]] = i along `dim`, i.e. invert the permutation
    new_indices.scatter_(dim=dim, index=x, src=arange)
    return new_indices
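# A small check of reverse_sort on a made-up toy batch: inverting the sort indices
# and gathering with them recovers the original order.
import torch

values = torch.tensor([[0.3, 0.1, 0.9, 0.5]])
sorted_vals, sort_idx = values.sort(dim=1)   # sort_idx == [[1, 0, 3, 2]]
inv_idx = reverse_sort(sort_idx, dim=1)      # also [[1, 0, 3, 2]] here (self-inverse)
restored = sorted_vals.gather(1, inv_idx)    # recovers `values`
assert torch.equal(restored, values)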