def _get_ins_targets_cpu(in_tokens, out_tokens, padding_idx, unk_idx):
    """Derive mask-insertion targets from a libnat edit path (CPU variant).

    Padding ids are stripped from every row of ``in_tokens`` and
    ``out_tokens``. ``libnat.suggested_ed2_path`` then aligns the two
    batches, and the per-slot insertion counts are turned into padded
    tensors.

    NOTE(review): ``libnat`` is not bound inside this function; it must be
    supplied by an enclosing or module scope -- confirm which.

    Args:
        in_tokens: 2-D tensor of current token ids, padded with
            ``padding_idx``.
        out_tokens: 2-D tensor of target token ids, padded with
            ``padding_idx``.
        padding_idx: padding id; removed before alignment.
        unk_idx: id written into ``out_tokens`` at every masked position.

    Returns:
        ``(masked_tgt_masks, masked_tgt_tokens, mask_ins_targets)``:
        - a bool mask shaped like ``out_tokens`` marking masked positions;
        - a new tensor equal to ``out_tokens`` with those positions set to
          ``unk_idx``;
        - per-slot insertion counts, zero-padded to
          ``in_tokens.size(1) - 1`` columns.
    """
    in_width = in_tokens.size(1)
    out_width = out_tokens.size(1)

    def _strip(batch):
        # Remove every padding id from each row of a 2-D tensor.
        return [[tok for tok in row if tok != padding_idx] for row in batch.tolist()]

    edit_paths = libnat.suggested_ed2_path(_strip(in_tokens), _strip(out_tokens), padding_idx)

    # Every element of a path except the last is treated as an insertion
    # slot; a slot whose first entry is padding_idx counts as size 0.
    slot_sizes = [
        [0 if slot[0] == padding_idx else len(slot) for slot in path[:-1]]
        for path in edit_paths
    ]

    masked_rows = []
    count_rows = []
    for sizes in slot_sizes:
        inner = sizes[1:-1]  # HACK 1:-1 -- the first and last slots are skipped
        # Each slot contributes a single 0 followed by `size` ones.
        bits = [bit for size in inner for bit in [0] + [1] * size]
        masked_rows.append(bits + [0] * (out_width - len(bits)))
        count_rows.append(inner + [0] * (in_width - 1 - len(inner)))

    masked_tgt_masks = torch.tensor(masked_rows, device=out_tokens.device).bool()
    mask_ins_targets = torch.tensor(count_rows, device=in_tokens.device)
    masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
    return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
def _get_del_targets(in_tokens, out_tokens, padding_idx):
    """Compute word-deletion labels for ``in_tokens`` relative to ``out_tokens``.

    Padding ids are removed from both 2-D batches and
    ``libnat.suggested_ed2_path`` aligns them. The last element of each
    resulting path is then zero-padded to ``out_tokens.size(1)``.

    NOTE(review): ``load_libnat`` is defined elsewhere; its return value is
    used directly as the libnat module here -- confirm it is not a tuple.

    Args:
        in_tokens: 2-D tensor of current token ids, padded with ``padding_idx``.
        out_tokens: 2-D tensor of target token ids, padded with ``padding_idx``.
        padding_idx: padding id; removed before alignment.

    Returns:
        Tensor of deletion labels on ``out_tokens.device``, one row per batch
        entry, ``out_tokens.size(1)`` columns wide.
    """
    libnat = load_libnat()
    target_width = out_tokens.size(1)

    with torch.cuda.device_of(in_tokens):
        hyp_rows = [[tok for tok in row if tok != padding_idx] for row in in_tokens.tolist()]
        ref_rows = [[tok for tok in row if tok != padding_idx] for row in out_tokens.tolist()]

    paths = libnat.suggested_ed2_path(hyp_rows, ref_rows, padding_idx)

    padded = []
    for path in paths:
        del_labels = path[-1]
        padded.append(del_labels + [0] * (target_width - len(del_labels)))

    # transform to tensor
    return torch.tensor(padded, device=out_tokens.device)
def _get_del_targets(in_tokens, out_tokens, padding_idx):
    """Compute word-deletion labels for ``in_tokens`` relative to ``out_tokens``.

    Padding ids are removed from both 2-D batches and
    ``libnat.suggested_ed2_path`` aligns them. The last element of each
    resulting path is then zero-padded to ``out_tokens.size(1)``.

    Args:
        in_tokens: 2-D tensor of current token ids, padded with ``padding_idx``.
        out_tokens: 2-D tensor of target token ids, padded with ``padding_idx``.
        padding_idx: padding id; removed before alignment.

    Returns:
        Tensor of deletion labels on ``out_tokens.device``, one row per batch
        entry, ``out_tokens.size(1)`` columns wide.

    Raises:
        ImportError: if the ``fairseq.libnat`` extension is not built.
    """
    try:
        from fairseq import libnat
    except ImportError as e:
        import sys

        sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
        raise e

    out_seq_len = out_tokens.size(1)
    in_tokens_list = [[t for t in s if t != padding_idx] for s in in_tokens.tolist()]
    out_tokens_list = [[t for t in s if t != padding_idx] for s in out_tokens.tolist()]
    full_labels = libnat.suggested_ed2_path(in_tokens_list, out_tokens_list, padding_idx)

    word_del_targets = [
        labels + [0] * (out_seq_len - len(labels))
        for labels in (b[-1] for b in full_labels)
    ]
    # Create the labels on out_tokens' device. Without an explicit device
    # they land on the default (CPU) device, which mismatches GPU inputs.
    return torch.tensor(word_del_targets, device=out_tokens.device)
def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
    """Compute deletion labels and mask-insertion counts from one alignment.

    Padding ids are removed from both 2-D batches and
    ``libnat.suggested_ed2_path`` aligns them once. Both targets are then
    read from the same paths.

    NOTE(review): ``load_libnat`` is defined elsewhere; its return value is
    used directly as the libnat module here -- confirm it is not a tuple.

    Args:
        in_tokens: 2-D tensor of current token ids, padded with ``padding_idx``.
        out_tokens: 2-D tensor of target token ids, padded with ``padding_idx``.
        padding_idx: padding id; removed before alignment.

    Returns:
        ``(word_del_targets, mask_ins_targets)``:
        - deletion labels (last path element) zero-padded to
          ``out_tokens.size(1)`` on ``out_tokens.device``;
        - per-slot insertion counts zero-padded to ``in_tokens.size(1) - 1``
          on ``in_tokens.device``.
    """
    libnat = load_libnat()
    in_width, out_width = in_tokens.size(1), out_tokens.size(1)

    with torch.cuda.device_of(in_tokens):
        src = [[tok for tok in row if tok != padding_idx] for row in in_tokens.tolist()]
        tgt = [[tok for tok in row if tok != padding_idx] for row in out_tokens.tolist()]

    paths = libnat.suggested_ed2_path(src, tgt, padding_idx)

    del_rows = []
    for path in paths:
        deletions = path[-1]
        del_rows.append(deletions + [0] * (out_width - len(deletions)))

    ins_rows = []
    for path in paths:
        # A slot whose first entry is padding_idx counts as size 0.
        sizes = [0 if slot[0] == padding_idx else len(slot) for slot in path[:-1]]
        inner = sizes[1:-1]
        ins_rows.append(inner + [0] * (in_width - 1 - len(inner)))

    # transform to tensor
    mask_ins_targets = torch.tensor(ins_rows, device=in_tokens.device)
    word_del_targets = torch.tensor(del_rows, device=out_tokens.device)
    return word_del_targets, mask_ins_targets
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tau=None):
    """Build scored insertion targets of shape ``(B, T - 1, vocab_size)``.

    For every token ``w`` at position ``k`` of inserted slot ``j`` in
    sentence ``i``, ``[i, j, w]`` is set to ``neg_scorer(k, len(slot), tau)``.
    Entries never written stay 0.

    NOTE(review): ``libnat`` and ``neg_scorer`` are not bound in this
    function; they must come from an enclosing or module scope.

    Args:
        in_tokens: 2-D tensor ``(B, T)`` of current token ids.
        out_tokens: 2-D tensor of target token ids.
        padding_idx: padding id; removed before alignment.
        unk_idx: unused here; kept for signature compatibility.
        vocab_size: size ``V`` of the last output dimension.
        tau: forwarded unchanged to ``neg_scorer``.

    Returns:
        Float tensor of shape ``(B, T - 1, vocab_size)``.
    """
    B = in_tokens.size(0)
    T = in_tokens.size(1)
    V = vocab_size

    with torch.cuda.device_of(in_tokens):
        in_tokens_list = [[t for t in s if t != padding_idx] for s in in_tokens.tolist()]
        out_tokens_list = [[t for t in s if t != padding_idx] for s in out_tokens.tolist()]

    full_labels = libnat.suggested_ed2_path(in_tokens_list, out_tokens_list, padding_idx)
    insert_labels = [a[:-1] for a in full_labels]

    # numericalize1: flat index of (i, j, w) in a (B, T - 1, V) buffer.
    insert_label_tensors = in_tokens.new_zeros(B * (T - 1) * V).float()
    insert_index, insert_scores = zip(*[
        (w + (j + i * (T - 1)) * V, neg_scorer(k, len(label), tau))
        for i, labels in enumerate(insert_labels)
        for j, label in enumerate(labels[1:-1])  # HACK 1:-1
        for k, w in enumerate(label)
    ])
    insert_index = torch.tensor(list(insert_index), device=in_tokens.device)
    insert_scores = torch.tensor(list(insert_scores), device=in_tokens.device)
    # torch.tensor infers the dtype from neg_scorer's return values (which may
    # be ints or float64); scatter_ needs src to match the float32 destination.
    insert_label_tensors.scatter_(
        0, insert_index.long(), insert_scores.type_as(insert_label_tensors)
    )
    return insert_label_tensors.view(B, T - 1, V)
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tau=None):
    """Build scored insertion targets of shape ``(B, T - 1, vocab_size)``.

    For every token at position ``rank`` of inserted slot ``slot_idx`` in
    sentence ``row_idx``, the entry ``[row_idx, slot_idx, token]`` is set to
    ``neg_scorer(rank, len(slot), tau)``. The scores are cast to the
    destination dtype before scattering (per the original author, this
    avoids a dtype error when ``tau`` is not None).

    NOTE(review): ``neg_scorer`` is not bound in this function; it must come
    from module scope.

    Args:
        in_tokens: 2-D tensor ``(B, T)`` of current token ids.
        out_tokens: 2-D tensor of target token ids.
        padding_idx: padding id; removed before alignment.
        unk_idx: unused here; kept for signature compatibility.
        vocab_size: size of the last output dimension.
        tau: forwarded unchanged to ``neg_scorer``.

    Returns:
        Float tensor of shape ``(B, T - 1, vocab_size)``.

    Raises:
        ImportError: if the ``fairseq.libnat`` extension is not built.
    """
    try:
        from fairseq import libnat
    except ImportError as e:
        import sys

        sys.stderr.write(
            "ERROR: missing libnat. run `pip install --editable .`\n")
        raise e

    batch_size, width = in_tokens.size(0), in_tokens.size(1)
    n_slots = width - 1

    with torch.cuda.device_of(in_tokens):
        src = [[tok for tok in row if tok != padding_idx] for row in in_tokens.tolist()]
        tgt = [[tok for tok in row if tok != padding_idx] for row in out_tokens.tolist()]

    paths = libnat.suggested_ed2_path(src, tgt, padding_idx)
    slot_lists = [path[:-1] for path in paths]

    # numericalize1
    flat_scores = in_tokens.new_zeros(batch_size * n_slots * vocab_size).float()
    entries = [
        ((slot_idx + row_idx * n_slots) * vocab_size + tok, neg_scorer(rank, len(slot), tau))
        for row_idx, slots in enumerate(slot_lists)
        for slot_idx, slot in enumerate(slots[1:-1])  # HACK 1:-1
        for rank, tok in enumerate(slot)
    ]
    positions, scores = zip(*entries)

    index_t = torch.tensor(list(positions), device=in_tokens.device)
    score_t = torch.tensor(list(scores), device=in_tokens.device)
    flat_scores.scatter_(0, index_t.long(), score_t.type_as(flat_scores))
    return flat_scores.view(batch_size, n_slots, vocab_size)
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
    """Derive mask-insertion targets from a libnat edit path.

    Padding ids are stripped from every row of ``in_tokens`` and
    ``out_tokens``. ``libnat.suggested_ed2_path`` then aligns the two
    batches, and the per-slot insertion counts are turned into padded
    tensors.

    Args:
        in_tokens: 2-D tensor of current token ids, padded with ``padding_idx``.
        out_tokens: 2-D tensor of target token ids, padded with ``padding_idx``.
        padding_idx: padding id; removed before alignment.
        unk_idx: id written into ``out_tokens`` at every masked position.

    Returns:
        ``(masked_tgt_masks, masked_tgt_tokens, mask_ins_targets)``:
        - a bool mask shaped like ``out_tokens`` marking masked positions;
        - a new tensor equal to ``out_tokens`` with those positions set to
          ``unk_idx``;
        - per-slot insertion counts, zero-padded to
          ``in_tokens.size(1) - 1`` columns.

    Raises:
        ImportError: if the ``fairseq.libnat`` extension is not built.
    """
    try:
        from fairseq import libnat
    except ImportError as e:
        import sys

        sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
        raise e

    in_width = in_tokens.size(1)
    out_width = out_tokens.size(1)

    with torch.cuda.device_of(in_tokens):
        src = [[tok for tok in row if tok != padding_idx] for row in in_tokens.tolist()]
        tgt = [[tok for tok in row if tok != padding_idx] for row in out_tokens.tolist()]

    paths = libnat.suggested_ed2_path(src, tgt, padding_idx)

    # Every element of a path except the last is treated as an insertion
    # slot; a slot whose first entry is padding_idx counts as size 0.
    slot_sizes = [
        [0 if slot[0] == padding_idx else len(slot) for slot in path[:-1]]
        for path in paths
    ]

    masked_rows = []
    count_rows = []
    for sizes in slot_sizes:
        inner = sizes[1:-1]  # HACK 1:-1 -- the first and last slots are skipped
        # Each slot contributes a single 0 followed by `size` ones.
        bits = [bit for size in inner for bit in [0] + [1] * size]
        masked_rows.append(bits + [0] * (out_width - len(bits)))
        count_rows.append(inner + [0] * (in_width - 1 - len(inner)))

    masked_tgt_masks = torch.tensor(masked_rows, device=out_tokens.device).bool()
    mask_ins_targets = torch.tensor(count_rows, device=in_tokens.device)
    masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
    return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tgt_dict, tau=None):
    """Build per-slot insertion weight distributions ``(B, T - 1, vocab_size)``.

    Each token ``w`` that the libnat edit path inserts into slot ``j`` of
    sentence ``i`` gets weight ``neg_scorer(w, tgt_dict, tau)`` at
    ``[i, j, w]``. Every slot is then divided by its total weight.

    NOTE(review): ``neg_scorer`` is not bound in this function; it must come
    from module scope.

    Args:
        in_tokens: 2-D tensor ``(B, T)`` of current token ids.
        out_tokens: 2-D tensor of target token ids.
        padding_idx: padding id; removed before alignment.
        unk_idx: unused here; kept for signature compatibility.
        vocab_size: size ``V`` of the last output dimension.
        tgt_dict: forwarded unchanged to ``neg_scorer``.
        tau: forwarded unchanged to ``neg_scorer``.

    Returns:
        Float tensor of shape ``(B, T - 1, vocab_size)``. Slots whose total
        weight is 0 (e.g. slots past the end of a shorter sentence, which
        are never written) are all zeros rather than NaN.

    Raises:
        ImportError: if the ``fairseq.libnat`` extension is not built.
    """
    try:
        from fairseq import libnat
    except ImportError as e:
        import sys

        sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
        raise e

    B = in_tokens.size(0)
    T = in_tokens.size(1)
    V = vocab_size

    with torch.cuda.device_of(in_tokens):
        in_tokens_list = [[t for t in s if t != padding_idx] for s in in_tokens.tolist()]
        out_tokens_list = [[t for t in s if t != padding_idx] for s in out_tokens.tolist()]

    full_labels = libnat.suggested_ed2_path(in_tokens_list, out_tokens_list, padding_idx)
    insert_labels = [a[:-1] for a in full_labels]

    # numericalize1: flat index of (i, j, w) in a (B, T - 1, V) buffer.
    insert_label_tensors = in_tokens.new_zeros(B * (T - 1) * V).float()
    insert_index, insert_weights = zip(
        *[
            (w + (j + i * (T - 1)) * V, neg_scorer(w, tgt_dict, tau))
            for i, labels in enumerate(insert_labels)
            for j, label in enumerate(labels[1:-1])  # HACK 1:-1
            for w in label
        ]
    )
    insert_index = torch.tensor(list(insert_index), device=in_tokens.device)
    insert_weights = torch.tensor(list(insert_weights), device=in_tokens.device)
    # torch.tensor infers the dtype from neg_scorer's return values; cast so
    # scatter_'s src matches the float32 destination.
    insert_label_tensors.scatter_(
        0, insert_index.long(), insert_weights.type_as(insert_label_tensors)
    )
    insert_label_tensors = insert_label_tensors.view(B, T - 1, V)

    # NegativeWeightScorerFreqWeight: divide each slot by its total weight.
    slot_totals = insert_label_tensors.sum(-1, keepdim=True)
    # Slots never written above total 0. Divide them by 1 instead so they
    # stay 0 rather than becoming NaN (0 / 0).
    slot_totals = slot_totals.masked_fill(slot_totals == 0, 1.0)
    return insert_label_tensors / slot_totals
def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
    """Compute deletion labels and mask-insertion counts from one alignment.

    Padding ids are removed from both 2-D batches and
    ``libnat.suggested_ed2_path`` aligns them once. Both targets are then
    read from the same paths.

    Args:
        in_tokens: 2-D tensor of current token ids, padded with ``padding_idx``.
        out_tokens: 2-D tensor of target token ids, padded with ``padding_idx``.
        padding_idx: padding id; removed before alignment.

    Returns:
        ``(word_del_targets, mask_ins_targets)``:
        - deletion labels (last path element) zero-padded to
          ``out_tokens.size(1)`` on ``out_tokens.device``;
        - per-slot insertion counts zero-padded to ``in_tokens.size(1) - 1``
          on ``in_tokens.device``.

    Raises:
        ImportError: if the ``fairseq.libnat`` extension is not built.
    """
    try:
        from fairseq import libnat
    except ImportError as e:
        import sys

        sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
        raise e

    in_width, out_width = in_tokens.size(1), out_tokens.size(1)

    with torch.cuda.device_of(in_tokens):
        src = [[tok for tok in row if tok != padding_idx] for row in in_tokens.tolist()]
        tgt = [[tok for tok in row if tok != padding_idx] for row in out_tokens.tolist()]

    paths = libnat.suggested_ed2_path(src, tgt, padding_idx)

    del_rows = []
    for path in paths:
        deletions = path[-1]
        del_rows.append(deletions + [0] * (out_width - len(deletions)))

    ins_rows = []
    for path in paths:
        # A slot whose first entry is padding_idx counts as size 0.
        sizes = [0 if slot[0] == padding_idx else len(slot) for slot in path[:-1]]
        inner = sizes[1:-1]
        ins_rows.append(inner + [0] * (in_width - 1 - len(inner)))

    # transform to tensor
    mask_ins_targets = torch.tensor(ins_rows, device=in_tokens.device)
    word_del_targets = torch.tensor(del_rows, device=out_tokens.device)
    return word_del_targets, mask_ins_targets