def translate(self, src_data_iter, tgt_data_iter, structure_iter, batch_size, out_file=None):
    """Decode every example from the input iterators and optionally write results.

    Batches are sorted/shuffled by the iterator for decoding efficiency;
    the original corpus order is restored afterwards via ``batch.indices``,
    so ``out_file`` lines align with the input order.

    Args:
        src_data_iter: iterator over source sentences.
        tgt_data_iter: iterator over target sentences (may be None).
        structure_iter: iterator over per-example structure data
            (presumably relation/adjacency matrices — TODO confirm).
        batch_size: number of examples per decoded batch.
        out_file: optional writable file object; one translation per line.
    """
    data = build_dataset(self.fields, src_data_iter, tgt_data_iter, None,
                         structure_iter, None, None, use_filter_pred=False)

    cur_device = "cuda" if self.cuda else "cpu"

    # sort/sort_within_batch/shuffle reorder examples for efficiency; the
    # scatter via batch.indices below undoes this for the final output.
    data_iter = OrderedIterator(
        dataset=data, device=cur_device, batch_size=batch_size,
        train=False, sort=True, sort_within_batch=True, shuffle=True)

    start_time = time.time()
    print("Begin decoding ...")
    batch_count = 0
    all_translation = []
    for batch in data_iter:
        hyps, scores = self.translate_batch(batch)
        assert len(batch) == len(hyps)

        batch_translation = []
        # batch.src[0] is (src_len, batch); transpose to iterate per example.
        for src_idx_seq, tran_idx_seq, score in zip(
                batch.src[0].transpose(0, 1), hyps, scores):
            src_words = self.build_tokens(src_idx_seq, side='src')
            src = ' '.join(src_words)
            tran_words = self.build_tokens(tran_idx_seq, side='tgt')
            tran = ' '.join(tran_words)
            batch_translation.append(tran)
            print("SOURCE: " + src + "\nOUTPUT: " + tran + "\n")

        # Scatter each translation back to its original corpus position,
        # growing the list on demand since batches arrive out of order.
        for index, tran in zip(batch.indices.data, batch_translation):
            while len(all_translation) <= index:
                all_translation.append("")
            all_translation[index] = tran

        batch_count += 1
        print("batch: " + str(batch_count) + "...")

    if out_file is not None:
        for tran in all_translation:
            out_file.write(tran + '\n')

    print('Decoding took %.1f minutes ...' % (float(time.time() - start_time) / 60.))
def translate(self, src_data_iter, tgt_data_iter, batch_size, out_file=None):
    """Decode every example and stream translations to ``out_file`` in corpus order.

    The iterator sorts within each batch only (``sort=False``,
    ``shuffle=False``), so global order is preserved across batches; within
    a batch the original order is restored by ``sort_translation`` using
    ``batch.indices`` offset by the running sentence count.

    Args:
        src_data_iter: iterator over source sentences; build_dataset yields
            one example per item (example.indices, example.src).
        tgt_data_iter: iterator over target sentences (may be None).
        batch_size: number of examples per decoded batch.
        out_file: optional writable file object; one translation per line.
    """
    data = build_dataset(self.fields, src_data_iter=src_data_iter,
                         tgt_data_iter=tgt_data_iter, use_filter_pred=False)

    def sort_translation(indices, translation):
        # `indices` is a 1-D tensor of batch-local original positions;
        # `translation` is a list aligned with the (sorted) batch order.
        ordered_translation = [None] * len(translation)
        for i, index in enumerate(indices):
            ordered_translation[index] = translation[i]
        return ordered_translation

    cur_device = "cuda" if self.cuda else "cpu"

    data_iter = OrderedIterator(dataset=data, device=cur_device,
                                batch_size=batch_size, train=False,
                                sort=False, sort_within_batch=True,
                                shuffle=False)

    start_time = time.time()
    print("Begin decoding ...")
    # BUG FIX: the original `idx = 0,` (trailing comma) bound idx to the
    # tuple (0,), making `batch.indices.data - idx` and `idx += len(batch)`
    # raise TypeError on the first batch.
    idx = 0
    for batch in data_iter:
        # hyps: one decoded index sequence per example; scores: one per example.
        hyps, scores = self.translate_batch(batch)
        assert len(batch) == len(hyps)

        translation = []
        for idx_seq, score in zip(hyps, scores):
            words = self.build_tokens(idx_seq, side='tgt')
            translation.append(' '.join(words))

        if out_file is not None:
            # batch.indices holds global positions; subtract the running
            # count to get batch-local slots for reordering.
            translation = sort_translation(batch.indices.data - idx, translation)
            for tran in translation:
                out_file.write(tran + '\n')

        idx += len(batch)
        print("sents " + str(idx) + "...")

    print('Decoding took %.1f minutes ...' % (float(time.time() - start_time) / 60.))
def translate(self, src_data_iter, tgt_data_iter, batch_size, out_file=None):
    """Decode all examples and write translations to ``out_file`` in input order.

    Batches keep global order (``sort=False``, ``shuffle=False``); only the
    within-batch sorting is undone, using ``batch.indices`` relative to the
    number of sentences already emitted.
    """
    dataset = build_dataset(self.fields, src_data_iter=src_data_iter,
                            tgt_data_iter=tgt_data_iter, use_filter_pred=False)

    def _restore_order(positions, sents):
        # Undo within-batch sorting: each sentence goes to its original slot.
        restored = [None] * len(sents)
        for batch_pos, orig_pos in enumerate(positions):
            restored[orig_pos] = sents[batch_pos]
        return restored

    device = "cuda" if self.cuda else "cpu"

    iterator = OrderedIterator(dataset=dataset, device=device,
                               batch_size=batch_size, train=False,
                               sort=False, sort_within_batch=True,
                               shuffle=False)

    t0 = time.time()
    print("Begin decoding ...")
    seen = 0
    for batch in iterator:
        hyps, scores = self.translate_batch(batch)
        assert len(batch) == len(hyps)

        # One space-joined sentence per hypothesis (scores are unused here
        # beyond pairing, matching the zip over both lists).
        sents = [' '.join(self.build_tokens(seq, side='tgt'))
                 for seq, _ in zip(hyps, scores)]

        if out_file is not None:
            # indices are global offsets; shift by `seen` to index this batch.
            for line in _restore_order(batch.indices.data - seen, sents):
                out_file.write(line + '\n')

        seen += len(batch)
        print("sents " + str(seen) + "...")

    print('Decoding took %.1f minutes ...' % (float(time.time() - t0) / 60.))
def translate(self, src_data_iter, tgt_data_iter, batch_size, out_file=None):
    """Decode every example and write translations to ``out_file`` in corpus order.

    The iterator performs no sorting or shuffling; translations are still
    scattered back into position via ``batch.indices`` before writing.

    Args:
        src_data_iter: iterator over source sentences.
        tgt_data_iter: iterator over target sentences (may be None).
        batch_size: number of examples per decoded batch.
        out_file: optional writable file object; one translation per line.
    """
    data = build_dataset(self.fields, src_data_iter=src_data_iter,
                         tgt_data_iter=tgt_data_iter, use_filter_pred=False)

    cur_device = "cuda" if self.cuda else "cpu"

    data_iter = OrderedIterator(
        dataset=data, device=cur_device, batch_size=batch_size,
        train=False, sort=False, sort_within_batch=False, shuffle=False)

    start_time = time.time()
    print("Begin decoding ...")
    batch_count = 0
    all_translation = []
    for batch in data_iter:
        hyps, scores = self.translate_batch(batch)
        assert len(batch) == len(hyps)

        batch_translation = []
        # batch.src[0] is (src_len, batch); transpose to iterate per example.
        for src_idx_seq, tran_idx_seq, score in zip(
                batch.src[0].transpose(0, 1), hyps, scores):
            tran_words = self.build_tokens(tran_idx_seq, side='tgt')
            tran = ' '.join(tran_words)
            batch_translation.append(tran)
            # NOTE(review): "Three-modal" is a hard-coded placeholder for the
            # source text (the src reconstruction was disabled upstream);
            # kept as-is since changing printed output changes behavior.
            print("SOURCE: " + "Three-modal" + "\nOUTPUT: " + tran + "\n")

        # Scatter each translation back to its original corpus position,
        # growing the list on demand.
        for index, tran in zip(batch.indices.data, batch_translation):
            while len(all_translation) <= index:
                all_translation.append("")
            all_translation[index] = tran

        batch_count += 1
        print("batch: " + str(batch_count) + "...")

    if out_file is not None:
        for tran in all_translation:
            out_file.write(tran + '\n')

    print('Decoding took %.1f minutes ...' % (float(time.time() - start_time) / 60.))