def build_save_dataset(corpus_type, fields, src_corpus, tgt_corpus,
                       savepath, args):
    """ Building and saving the dataset """
    assert corpus_type in ['train', 'dev', 'test']

    dataset = inputters.build_dataset(fields,
                                      data_type='text',
                                      src_path=src_corpus,
                                      tgt_path=tgt_corpus,
                                      src_dir='',
                                      src_seq_length=args.max_src_len,
                                      tgt_seq_length=args.max_tgt_len,
                                      src_seq_length_trunc=0,
                                      tgt_seq_length_trunc=0,
                                      dynamic_dict=True)

    # We save fields in vocab.pt separately, so make it empty.
    dataset.fields = []

    # Attach a molecular graph, built from the source SMILES, to every example.
    for i in range(len(dataset)):
        if i % 500 == 0:
            print(i)
        setattr(dataset.examples[i], 'graph',
                myutils.str2graph(dataset.examples[i].src))

    pt_file = "{:s}/{:s}.pt".format(savepath, corpus_type)
    # torch.save(dataset, pt_file)
    with open(pt_file, 'wb') as f:
        pickle.dump(dataset, f)

    return [pt_file]
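# Hypothetical driver sketch for the function above: the corpus paths
# and length limits are placeholders, and `fields` is assumed to come
# from the repo's usual vocabulary-building step (it is not defined
# here).
from argparse import Namespace

args = Namespace(max_src_len=400, max_tgt_len=400)
build_save_dataset('train', fields,
                   src_corpus='data/src-train.txt',
                   tgt_corpus='data/tgt-train.txt',
                   savepath='data', args=args)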
def parrel_func(example, have_3d):
    """Build a graph for one example's source SMILES; used as a map worker."""
    graph = None
    try:
        smile = example.src
        graph = myutils.str2graph(smile, have_3d)
    except Exception as e:
        # Leave graph as None for molecules that fail to parse.
        print(e)
        # raise e
    setattr(example, 'graph', graph)
    return example
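# A sketch of mapping parrel_func over examples with multiprocessing;
# SimpleNamespace stands in for torchtext Examples here (in the repo
# the inputs would be dataset.examples), and the SMILES strings and
# process count are placeholders.
from functools import partial
from multiprocessing import Pool
from types import SimpleNamespace

if __name__ == '__main__':
    examples = [SimpleNamespace(src=s) for s in ['CCO', 'c1ccccc1']]
    with Pool(processes=2) as pool:
        examples = pool.map(partial(parrel_func, have_3d=False), examples)
    print(examples[0].graph)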
def build_save_dataset(corpus_type, fields, opt):
    """ Building and saving the dataset """
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src_corpus = opt.train_src
        tgt_corpus = opt.train_tgt
    else:
        src_corpus = opt.valid_src
        tgt_corpus = opt.valid_tgt

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(src_corpus,
                                                      tgt_corpus,
                                                      fields,
                                                      corpus_type,
                                                      opt)

    # For data_type == 'img' or 'audio', we currently do not shard
    # during preprocessing; we only build a monolithic dataset.
    # But since the interfaces are uniform, it would not be hard
    # to add this should users need the feature.
    dataset = inputters.build_dataset(
        fields, "text",
        src_path=src_corpus,
        tgt_path=tgt_corpus,
        src_dir=opt.src_dir,
        src_seq_length=opt.src_seq_length,
        tgt_seq_length=opt.tgt_seq_length,
        src_seq_length_trunc=opt.src_seq_length_trunc,
        tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
        dynamic_dict=opt.dynamic_dict,
        sample_rate=opt.sample_rate,
        window_size=opt.window_size,
        window_stride=opt.window_stride,
        window=opt.window,
        image_channel_size=opt.image_channel_size)

    # We save fields in vocab.pt separately, so make it empty.
    dataset.fields = []

    pt_file = "{:s}.{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))

    # Attach a molecular graph, built from the source SMILES, to every example.
    for i in range(len(dataset)):
        if i % 500 == 0:
            print(i)
        setattr(dataset.examples[i], 'graph',
                myutils.str2graph(dataset.examples[i].src))

    # torch.save(dataset, pt_file)
    with open(pt_file, 'wb') as f:
        pickle.dump(dataset, f)

    return [pt_file]
def translate(self,
              src_path=None,
              src_data_iter=None,
              tgt_path=None,
              tgt_data_iter=None,
              src_dir=None,
              batch_size=None,
              attn_debug=False):
    """
    Translate the content of `src_data_iter` (if not None) or `src_path`,
    and compute gold scores if one of `tgt_data_iter` or `tgt_path` is set.

    Note: batch_size must not be None
    Note: one of ('src_path', 'src_data_iter') must not be None

    Args:
        src_path (str): filepath of source data
        src_data_iter (iterator): an iterator generating source data,
            e.g. it may be a list or an opened file
        tgt_path (str): filepath of target data
        tgt_data_iter (iterator): an iterator generating target data
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists
          of `n_best` predictions
    """
    assert src_data_iter is not None or src_path is not None

    if batch_size is None:
        raise ValueError("batch_size must be set")
    data = inputters.build_dataset(self.fields,
                                   self.data_type,
                                   src_path=src_path,
                                   src_data_iter=src_data_iter,
                                   tgt_path=tgt_path,
                                   tgt_data_iter=tgt_data_iter,
                                   src_dir=src_dir,
                                   sample_rate=self.sample_rate,
                                   window_size=self.window_size,
                                   window_stride=self.window_stride,
                                   window=self.window,
                                   use_filter_pred=self.use_filter_pred,
                                   image_channel_size=self.image_channel_size)

    # Add the graph field: build a molecular graph from each source SMILES.
    for i in range(len(data)):
        if i % 500 == 0:
            print(i)
        setattr(data.examples[i], 'graph',
                myutils.str2graph(data.examples[i].src))

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    data_iter = inputters.OrderedIterator(
        dataset=data, device=cur_device,
        batch_size=batch_size, train=False, sort=False,
        sort_within_batch=True, shuffle=False)

    builder = onmt.translate.TranslationBuilder(
        data, self.fields,
        self.n_best, self.replace_unk, tgt_path)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    for batch in data_iter:
        # Pad every graph in the batch to the longest source length.
        batch.graph = myutils.pad_for_graph(batch.graph,
                                            torch.max(batch.src[1]).item())
        batch_data = self.translate_batch(batch, data, fast=self.fast)
        translations = builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt_path is not None:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.log_probs_out_file is not None:
                self.log_probs_out_file.write(
                    '\n'.join([str(t.item())
                               for t in trans.pred_scores[:self.n_best]])
                    + '\n')
                self.log_probs_out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            # Debug attention: print one row per predicted token, one
            # column per source token, with the row maximum starred.
            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *srcs) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                os.write(1, output.encode('utf-8'))

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
        if tgt_path is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

    if self.report_bleu:
        msg = self._report_bleu(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
    if self.report_rouge:
        msg = self._report_rouge(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))
    return all_scores, all_predictions
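# Hypothetical call sketch for translate(); `translator` is assumed to
# be an already-constructed Translator instance and the path below is
# a placeholder.
scores, predictions = translator.translate(src_path='data/src-test.txt',
                                           batch_size=30)
print(predictions[0][0])  # best hypothesis for the first source line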
def build_save_in_shards_using_shards_size(src_corpus, tgt_corpus, fields,
                                           corpus_type, opt):
    """
    Divide src_corpus and tgt_corpus into multiple smaller src and tgt
    corpus files, then build a shard from each pair; every shard has
    opt.shard_size samples except the last one.

    We do this to avoid the memory cost of reading a huge corpus
    file in at once.
    """

    with codecs.open(src_corpus, "r", encoding="utf-8") as fsrc:
        with codecs.open(tgt_corpus, "r", encoding="utf-8") as ftgt:
            logger.info("Reading source and target files: %s %s."
                        % (src_corpus, tgt_corpus))
            src_data = fsrc.readlines()
            tgt_data = ftgt.readlines()

            num_shards = int(len(src_data) / opt.shard_size)
            for x in range(num_shards):
                logger.info("Splitting shard %d." % x)
                f = codecs.open(src_corpus + ".{0}.txt".format(x), "w",
                                encoding="utf-8")
                f.writelines(
                    src_data[x * opt.shard_size:(x + 1) * opt.shard_size])
                f.close()
                f = codecs.open(tgt_corpus + ".{0}.txt".format(x), "w",
                                encoding="utf-8")
                f.writelines(
                    tgt_data[x * opt.shard_size:(x + 1) * opt.shard_size])
                f.close()

            num_written = num_shards * opt.shard_size
            if len(src_data) > num_written:
                logger.info("Splitting shard %d." % num_shards)
                f = codecs.open(src_corpus + ".{0}.txt".format(num_shards),
                                'w', encoding="utf-8")
                f.writelines(src_data[num_shards * opt.shard_size:])
                f.close()
                f = codecs.open(tgt_corpus + ".{0}.txt".format(num_shards),
                                'w', encoding="utf-8")
                f.writelines(tgt_data[num_shards * opt.shard_size:])
                f.close()

    src_list = sorted(glob.glob(src_corpus + '.*.txt'))
    tgt_list = sorted(glob.glob(tgt_corpus + '.*.txt'))

    ret_list = []

    for index, src in enumerate(src_list):
        logger.info("Building shard %d." % index)
        dataset = inputters.build_dataset(
            fields, opt.data_type,
            src_path=src,
            tgt_path=tgt_list[index],
            src_dir=opt.src_dir,
            src_seq_length=opt.src_seq_length,
            tgt_seq_length=opt.tgt_seq_length,
            src_seq_length_trunc=opt.src_seq_length_trunc,
            tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
            dynamic_dict=opt.dynamic_dict,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            image_channel_size=opt.image_channel_size)

        pt_file = "{:s}.{:s}.{:d}.pt".format(
            opt.save_data, corpus_type, index)

        # We save fields in vocab.pt separately, so make it empty.
        dataset.fields = []

        logger.info(" * saving %sth %s data shard to %s."
                    % (index, corpus_type, pt_file))

        # Attach a molecular graph, built from the source SMILES, to
        # every example in this shard.
        for i in range(len(dataset)):
            if i % 500 == 0:
                print(i)
            setattr(dataset.examples[i], 'graph',
                    myutils.str2graph(dataset.examples[i].src))

        # torch.save(dataset, pt_file)
        with open(pt_file, 'wb') as f:
            pickle.dump(dataset, f)

        ret_list.append(pt_file)
        os.remove(src)
        os.remove(tgt_list[index])
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return ret_list
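# Minimal, self-contained illustration of the shard arithmetic above:
# floor(n / shard_size) full shards plus one leftover shard, matching
# the slicing done on src_data/tgt_data.
def shard_ranges(n_samples, shard_size):
    """Yield (start, end) slice bounds covering n_samples."""
    num_shards = n_samples // shard_size
    for x in range(num_shards):
        yield x * shard_size, (x + 1) * shard_size
    if n_samples > num_shards * shard_size:
        yield num_shards * shard_size, n_samples

print(list(shard_ranges(10, 4)))  # [(0, 4), (4, 8), (8, 10)]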
import torch
import dgl
import pickle
import torchtext
import onmt.myutils as myutils

# Earlier attempt with torch.save/torch.load, kept for reference:
# dataset = torch.load('data/seqdata.train.0.pt')
# g = myutils.str2graph(dataset.examples[0].src)
# setattr(dataset.examples[0], 'graph', g)
# print(g)
# torch.save(dataset, 'g.pt')
# a = torch.load('g.pt')
# print(a.examples[0].graph)
# print(a)

# Sanity check: a dataset with a DGL graph attached survives a
# pickle round-trip.
dataset = torch.load('data/seqdata.train.0.pt')
g = myutils.str2graph(dataset.examples[0].src)
setattr(dataset.examples[0], 'graph', g)
with open('g.pt', 'wb') as f:
    pickle.dump(dataset, f)
with open('g.pt', 'rb') as f:
    a = pickle.load(f)
print(a)
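# The same check reduced to its essentials: a bare DGL graph through a
# pickle round-trip. A minimal sketch assuming a recent DGL API
# (dgl.graph); older DGL versions used dgl.DGLGraph() instead.
g_min = dgl.graph(([0, 1], [1, 2]))
restored = pickle.loads(pickle.dumps(g_min))
print(restored.num_nodes(), restored.num_edges())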