def build_fields(ds: Dataset):
    """Create torchtext fields for *ds* and populate their vocabularies.

    Loads the source and target vocab files named in ``ds.vocab`` into a
    shared frequency counter, then builds fields using the project-wide
    ``defaults.vocabulary`` settings. Word features are not supported, so
    both feature counts are fixed at zero.
    """
    # One Counter per side ('src'/'tgt'), filled by the vocab loaders.
    token_counts = defaultdict(Counter)
    inputter._load_vocab(ds.vocab.source, 'src', token_counts)
    inputter._load_vocab(ds.vocab.target, 'tgt', token_counts)

    # Word features are not supported for now.
    n_src_feats = n_tgt_feats = 0
    fields = inputter.get_fields(
        defaults.vocabulary["data_type"], n_src_feats, n_tgt_feats)
    return inputter._build_fields_vocab(
        fields, token_counts, **defaults.vocabulary)
def maybe_load_vocab(corpus_type, counters, opt):
    """Load pre-built fields or plain-text vocab files, when configured.

    Only the "train" corpus triggers any loading. If ``opt.src_vocab``
    points at a torch-serialized file, it is loaded wholesale as
    ``existing_fields``; otherwise it is read as a plain-text vocab and
    folded into *counters*. Returns ``(src_vocab, tgt_vocab,
    existing_fields)``, each ``None`` when nothing was loaded.
    """
    src_vocab = tgt_vocab = existing_fields = None
    if corpus_type != "train":
        return src_vocab, tgt_vocab, existing_fields

    if opt.src_vocab != "":
        try:
            # A torch-serialized file already contains finished fields.
            logger.info("Using existing vocabulary...")
            existing_fields = torch.load(opt.src_vocab)
        except torch.serialization.pickle.UnpicklingError:
            # Not a torch file: fall back to a plain-text vocab listing.
            logger.info("Building vocab from text file...")
            src_vocab, _ = _load_vocab(
                opt.src_vocab, "src", counters,
                opt.src_words_min_frequency)
    if opt.tgt_vocab != "":
        tgt_vocab, _ = _load_vocab(
            opt.tgt_vocab, "tgt", counters,
            opt.tgt_words_min_frequency)
    return src_vocab, tgt_vocab, existing_fields
def build_dynamic_fields(opts, src_specials=None, tgt_specials=None):
    """Build fields for dynamic, including load & build vocab."""
    fields = _get_dynamic_fields(opts)
    counters = defaultdict(Counter)

    # Source vocab is mandatory; the returned vocab/size are unused here
    # because only the counters feed _build_fields_vocab below.
    logger.info("Loading vocab from text file...")
    _, _ = _load_vocab(
        opts.src_vocab, 'src', counters,
        min_freq=opts.src_words_min_frequency)

    # Optional per-feature vocab files, loaded with no frequency cutoff.
    if opts.src_feats_vocab:
        for feat_name, filepath in opts.src_feats_vocab.items():
            _, _ = _load_vocab(filepath, feat_name, counters, min_freq=0)

    if opts.tgt_vocab:
        _, _ = _load_vocab(
            opts.tgt_vocab, 'tgt', counters,
            min_freq=opts.tgt_words_min_frequency)
    elif opts.share_vocab:
        # Alias the src counter so both sides build over one vocabulary.
        logger.info("Sharing src vocab to tgt...")
        counters['tgt'] = counters['src']
    else:
        raise ValueError("-tgt_vocab should be specified if not share_vocab.")

    logger.info("Building fields with vocab in counters...")
    return _build_fields_vocab(
        fields, counters, 'text', opts.share_vocab,
        opts.vocab_size_multiple,
        opts.src_vocab_size, opts.src_words_min_frequency,
        opts.tgt_vocab_size, opts.tgt_words_min_frequency,
        src_specials=src_specials, tgt_specials=tgt_specials)
def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
    """Shard the src/tgt corpora, build datasets, and save them to disk.

    For ``corpus_type == "train"`` it also accumulates token counters over
    the examples (unless a pre-built fields file was loaded) and finally
    saves the vocab fields to ``opt.save_data + '.vocab.pt'``.

    NOTE(review): assumes len(opt.train_src) == len(opt.train_tgt) ==
    len(opt.train_ids) — zip silently truncates otherwise; confirm at
    the option-parsing layer.
    """
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        # Counters are shared across all (src, tgt) corpus pairs/shards.
        counters = defaultdict(Counter)
        srcs = opt.train_src
        tgts = opt.train_tgt
        ids = opt.train_ids
    else:
        srcs = [opt.valid_src]
        tgts = [opt.valid_tgt]
        ids = [None]
    logger.info(opt)
    for src, tgt, maybe_id in zip(srcs, tgts, ids):
        logger.info("Reading source and target files: %s %s." % (src, tgt))

        src_shards = split_corpus(src, opt.shard_size)
        tgt_shards = split_corpus(tgt, opt.shard_size)
        shard_pairs = zip(src_shards, tgt_shards)
        dataset_paths = []
        # Length-based filtering applies to train always, and to valid
        # only when opt.filter_valid is set; needs a target side.
        if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
            filter_pred = partial(
                inputters.filter_example, use_src_len=opt.data_type == "text",
                max_src_len=opt.src_seq_length, max_tgt_len=opt.tgt_seq_length)
        else:
            filter_pred = None

        if corpus_type == "train":
            existing_fields = None
            if opt.src_vocab != "":
                try:
                    # opt.src_vocab may be a torch-serialized fields file...
                    logger.info("Using existing vocabulary...")
                    existing_fields = torch.load(opt.src_vocab)
                except torch.serialization.pickle.UnpicklingError:
                    # ...or a plain-text vocab; fold it into the counters.
                    logger.info("Building vocab from text file...")
                    src_vocab, src_vocab_size = _load_vocab(
                        opt.src_vocab, "src", counters,
                        opt.src_words_min_frequency)
            else:
                src_vocab = None
            if opt.tgt_vocab != "":
                tgt_vocab, tgt_vocab_size = _load_vocab(
                    opt.tgt_vocab, "tgt", counters,
                    opt.tgt_words_min_frequency)
            else:
                tgt_vocab = None

        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            assert len(src_shard) == len(tgt_shard)
            logger.info("Building shard %d." % i)
            # @memray: to be different from normal datasets
            dataset = inputters.str2dataset[opt.data_type](
                fields,
                readers=([src_reader, tgt_reader]
                         if tgt_reader else [src_reader]),
                data=([("src", src_shard), ("tgt", tgt_shard)]
                      if tgt_reader else [("src", src_shard)]),
                dirs=([opt.src_dir, None]
                      if tgt_reader else [opt.src_dir]),
                sort_key=inputters.str2sortkey[opt.data_type],
                filter_pred=filter_pred)
            # Count tokens only when no pre-built fields were loaded.
            if corpus_type == "train" and existing_fields is None:
                for ex in dataset.examples:
                    for name, field in fields.items():
                        try:
                            # Multi-field entries iterate (sub_name, field)
                            # pairs; scalar fields raise TypeError.
                            f_iter = iter(field)
                        except TypeError:
                            f_iter = [(name, field)]
                            all_data = [getattr(ex, name, None)]
                        else:
                            all_data = getattr(ex, name)
                        for (sub_n, sub_f), fd in zip(f_iter, all_data):
                            # Sides whose vocab came from a text file are
                            # already counted; skip re-counting them.
                            has_vocab = (sub_n == 'src' and src_vocab is not None) or \
                                        (sub_n == 'tgt' and tgt_vocab is not None)
                            if (hasattr(sub_f, 'sequential')
                                    and sub_f.sequential and not has_vocab):
                                val = fd
                                if opt.data_type == 'keyphrase' and sub_n == 'tgt':
                                    # in this case, val is a list of phrases (list of strings (words))
                                    for v in val:
                                        counters[sub_n].update(v)
                                else:
                                    counters[sub_n].update(val)
            # Per-corpus shard files are suffixed with the corpus id.
            if maybe_id:
                shard_base = corpus_type + "_" + maybe_id
            else:
                shard_base = corpus_type
            data_path = "{:s}.{:s}.{:d}.pt".\
                format(opt.save_data, shard_base, i)
            dataset_paths.append(data_path)

            logger.info(" * saving %sth %s data shard to %s. %d examples"
                        % (i, corpus_type, data_path, len(dataset.examples)))

            dataset.save(data_path)

            # Free the shard eagerly — shards can be large.
            del dataset.examples
            gc.collect()
            del dataset
            gc.collect()

    if corpus_type == "train":
        vocab_path = opt.save_data + '.vocab.pt'
        if existing_fields is None:
            # Build final vocabs from the accumulated counters.
            fields = _build_fields_vocab(
                fields, counters, opt.data_type,
                opt.share_vocab, opt.vocab_size_multiple,
                opt.src_vocab_size, opt.src_words_min_frequency,
                opt.tgt_vocab_size, opt.tgt_words_min_frequency)
        else:
            fields = existing_fields
        torch.save(fields, vocab_path)
def build_save_dataset(corpus_type, fields, src_reader, opt):
    """Build and save datasets from tab-separated classification files.

    Reads lines of exactly 6 tab-separated columns (id, ?, ?, sent1,
    sent2, label — the first valid line is treated as a header and
    skipped), builds an ``inputters.Dataset`` over (id, sent1, sent2,
    label, prelogit1, prelogit2), saves one .pt shard per corpus, and for
    train also builds and saves vocab fields.

    NOTE(review): this redefines another ``build_save_dataset`` earlier in
    the module — the later definition shadows the earlier one; confirm
    this is intentional.
    """
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        counters = defaultdict(Counter)
        srcs = opt.train_src
        ids = opt.train_ids
    else:
        srcs = [opt.valid_src]
        ids = [None]

    for src, maybe_id in zip(srcs, ids):
        logger.info("Reading source files: %s." % src)
        # src_shards = split_corpus(src, opt.shard_size)
        # Shard size 0 — presumably "no sharding", one shard per file;
        # TODO confirm split_corpus semantics for size 0.
        src_shards = split_corpus(src, 0)
        dataset_paths = []
        # if (corpus_type == "train" or opt.filter_valid):
        # filter_pred = partial(
        # inputters.filter_example, use_src_len=opt.data_type == "text",
        # max_src_len=opt.src_seq_length, max_tgt_len=opt.tgt_seq_length)
        # else:
        filter_pred = None

        if corpus_type == "train":
            existing_fields = None
            if opt.src_vocab != "":
                try:
                    # opt.src_vocab may be a torch-serialized fields file...
                    logger.info("Using existing vocabulary...")
                    existing_fields = torch.load(opt.src_vocab)
                except torch.serialization.pickle.UnpicklingError:
                    # ...or a plain-text vocab; fold it into the counters.
                    logger.info("Building vocab from text file...")
                    src_vocab, src_vocab_size = _load_vocab(
                        opt.src_vocab, "src", counters,
                        opt.src_words_min_frequency)
            else:
                src_vocab = None

        for i, _src_shard in enumerate(src_shards):
            # not considered shard
            logger.info("Building shard %d." % i)
            # Keep only well-formed rows with exactly 6 tab-separated
            # columns; malformed lines are silently dropped.
            src_shard = []
            for j, line in enumerate(_src_shard):
                if len(line.strip().split("\t")) == 6:
                    src_shard.append(line)
            # [1:] skips the first kept row — presumably a header line;
            # TODO confirm the input files carry a header.
            _id = [line.strip().split("\t")[0] for line in src_shard[1:]]
            # _id = [i for i, line in enumerate(src_shard[1:], 1)]
            sent1 = [line.strip().split("\t")[3] for line in src_shard[1:]]
            sent2 = [line.strip().split("\t")[4] for line in src_shard[1:]]
            # Pre-logits are placeholders, zero-filled at build time.
            prelogit1 = [0.0 for _ in src_shard[1:]]
            prelogit2 = [0.0 for _ in src_shard[1:]]
            # Binarize the label column: positive strings -> 1, else 0.
            label = []
            for line in src_shard[1:]:
                token = line.strip().split("\t")[5]
                if token in ["Good", "entailment", "1", 1]:
                    label.append(1)
                else:
                    label.append(0)
            dataset = inputters.Dataset(
                fields,
                readers=([src_reader, src_reader, src_reader,
                          src_reader, src_reader, src_reader]),
                data=([("id", _id), ("sent1", sent1), ("sent2", sent2),
                       ("label", label), ("prelogit1", prelogit1),
                       ("prelogit2", prelogit2)]),
                # data=([("src", src_shard), ("tgt", tgt_shard)]
                # if tgt_reader else [("src", src_shard)]),
                dirs=([None, None, None, None, None, None]),
                sort_key=inputters.str2sortkey[opt.data_type],
                filter_pred=filter_pred)
            # Count tokens only when no pre-built fields were loaded.
            if corpus_type == "train" and existing_fields is None:
                for ex in dataset.examples:
                    for name, field in fields.items():
                        # Non-text fields contribute nothing to the vocab.
                        if name in ["label", "id", "prelogit1", "prelogit2"]:
                            continue
                        try:
                            # Multi-field entries iterate (sub_name, field)
                            # pairs; scalar fields raise TypeError.
                            f_iter = iter(field)
                        except TypeError:
                            f_iter = [(name, field)]
                            all_data = [getattr(ex, name, None)]
                        else:
                            all_data = getattr(ex, name)
                        for (sub_n, sub_f), fd in zip(f_iter, all_data):
                            # Skip src if its vocab came from a text file
                            # (already counted by _load_vocab).
                            has_vocab = (sub_n == 'src' and src_vocab)
                            if (hasattr(sub_f, 'sequential')
                                    and sub_f.sequential and not has_vocab):
                                val = fd
                                counters[sub_n].update(val)
            # Per-corpus shard files are suffixed with the corpus id.
            if maybe_id:
                shard_base = corpus_type + "_" + maybe_id
            else:
                shard_base = corpus_type
            data_path = "{:s}.{:s}.{:d}.pt".\
                format(opt.save_data, shard_base, i)
            dataset_paths.append(data_path)

            logger.info(" * saving %sth %s data shard to %s."
                        % (i, shard_base, data_path))

            dataset.save(data_path)

            # Free the shard eagerly — shards can be large.
            del dataset.examples
            gc.collect()
            del dataset
            gc.collect()

    if corpus_type == "train":
        vocab_path = opt.save_data + '.vocab.pt'
        if existing_fields is None:
            # Build final vocabs from the accumulated counters.
            fields = _build_fields_vocab(fields, counters,
                                         opt.data_type, opt.share_vocab,
                                         opt.vocab_size_multiple,
                                         opt.src_vocab_size,
                                         opt.src_words_min_frequency)
        else:
            fields = existing_fields
        torch.save(fields, vocab_path)