# NOTE: assumed imports for this excerpt. Project-level helpers and constants
# (count_file_lines, process_examples, Vocabulary, BOS_WORD/EOS_WORD/PAD_WORD/
# UNK_WORD, LANG_ID_MAP, DATA_LANG_MAP) are expected to come from the
# surrounding package.
import json
import logging

import torch
from tqdm import tqdm

logger = logging.getLogger(__name__)


def load_data(args, filenames, max_examples=-1, dataset_name='java',
              test_split=False):
    """Load examples from preprocessed file. One example per line, JSON encoded."""
    with open(filenames['src']) as f:
        sources = [line.strip() for line in
                   tqdm(f, total=count_file_lines(filenames['src']))]

    if filenames['tgt'] is not None:
        with open(filenames['tgt']) as f:
            targets = [line.strip() for line in
                       tqdm(f, total=count_file_lines(filenames['tgt']))]
    else:
        targets = [None] * len(sources)

    if filenames['src_tag'] is not None:
        with open(filenames['src_tag']) as f:
            source_tags = [line.strip() for line in
                           tqdm(f, total=count_file_lines(filenames['src_tag']))]
    else:
        source_tags = [None] * len(sources)

    if args.use_tree_relative_attn:  # i.e. filenames["rel_matrix"] is not None
        with open(filenames["rel_matrix"]) as f:
            rel_matrices = [json.loads(line) for line in
                            tqdm(f, total=count_file_lines(filenames["rel_matrix"]))]
    else:
        rel_matrices = [None] * len(sources)

    # Sanity check: all four streams must be aligned one-to-one.
    print(len(sources), len(source_tags), len(targets), len(rel_matrices))
    assert len(sources) == len(source_tags) == len(targets) == len(rel_matrices)

    examples = []
    for src, src_tag, tgt, rel_matrix in tqdm(zip(sources, source_tags, targets,
                                                  rel_matrices),
                                              total=len(sources)):
        if dataset_name in ['java', 'python']:
            _ex = process_examples(LANG_ID_MAP[DATA_LANG_MAP[dataset_name]],
                                   src,
                                   src_tag,
                                   tgt,
                                   rel_matrix,
                                   args.max_src_len,
                                   args.max_tgt_len,
                                   args.code_tag_type,
                                   uncase=args.uncase,
                                   test_split=test_split,
                                   split_tokens=args.sum_over_subtokens)
            if _ex is not None:
                examples.append(_ex)

        if max_examples != -1 and len(examples) > max_examples:
            break

    return examples
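# Both load_data variants (and the embedding helpers further below) call a
# `count_file_lines` helper that is not shown in this excerpt; it only feeds
# tqdm an accurate `total`. The definition below is a minimal sketch of what
# such a helper could look like -- the project's own implementation may differ.
def count_file_lines(file_path):
    """Return the number of lines in `file_path` (used for progress totals)."""
    with open(file_path) as f:
        return sum(1 for _ in f)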
# Variant of load_data without the tree-relative-attention relation matrices.
def load_data(args, filenames, max_examples=-1, dataset_name='java',
              test_split=False):
    """Load examples from preprocessed file. One example per line, JSON encoded."""
    with open(filenames['src']) as f:
        sources = [line.strip() for line in
                   tqdm(f, total=count_file_lines(filenames['src']))]

    if filenames['tgt'] is not None:
        with open(filenames['tgt']) as f:
            targets = [line.strip() for line in
                       tqdm(f, total=count_file_lines(filenames['tgt']))]
    else:
        targets = [None] * len(sources)

    if filenames['src_tag'] is not None:
        with open(filenames['src_tag']) as f:
            source_tags = [line.strip() for line in
                           tqdm(f, total=count_file_lines(filenames['src_tag']))]
    else:
        source_tags = [None] * len(sources)

    assert len(sources) == len(source_tags) == len(targets)

    examples = []
    for src, src_tag, tgt in tqdm(zip(sources, source_tags, targets),
                                  total=len(sources)):
        if dataset_name in ['java', 'python']:
            _ex = process_examples(LANG_ID_MAP[DATA_LANG_MAP[dataset_name]],
                                   src,
                                   src_tag,
                                   tgt,
                                   args.max_src_len,
                                   args.max_tgt_len,
                                   args.code_tag_type,
                                   uncase=args.uncase,
                                   test_split=test_split)
            if _ex is not None:
                examples.append(_ex)

        if max_examples != -1 and len(examples) > max_examples:
            break

    return examples
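# Minimal usage sketch for load_data. The attribute names on `args` mirror the
# ones read above; the concrete values, file paths, and the _demo_load_data
# wrapper itself are hypothetical and only illustrate the expected `filenames`
# keys ('src', 'src_tag', 'tgt', plus 'rel_matrix' when tree-relative
# attention is enabled).
def _demo_load_data():
    from argparse import Namespace

    args = Namespace(max_src_len=400,
                     max_tgt_len=30,
                     code_tag_type='subtoken',
                     uncase=True,
                     sum_over_subtokens=False,
                     use_tree_relative_attn=False)
    filenames = {'src': 'train/code.original_subtoken',
                 'src_tag': None,
                 'tgt': 'train/javadoc.original',
                 'rel_matrix': None}
    examples = load_data(args, filenames, max_examples=1000,
                         dataset_name='java', test_split=False)
    print('loaded %d examples' % len(examples))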
def index_embedding_words(embedding_file):
    """Put all the words in embedding_file into a set."""
    words = set()
    with open(embedding_file) as f:
        for line in tqdm(f, total=count_file_lines(embedding_file)):
            w = Vocabulary.normalize(line.rstrip().split(' ')[0])
            words.add(w)

    words.update([BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD])
    return words
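# index_embedding_words pairs naturally with load_embeddings below: the set it
# returns can be passed as the `words` argument, so that only tokens which
# appear in the pretrained file *and* are indexed in the model vocabulary get
# initialized. The wrapper name and its arguments are hypothetical.
def _demo_restrict_to_pretrained(word_dict, emb_layer, embedding_file):
    pretrained_words = index_embedding_words(embedding_file)
    load_embeddings(word_dict, pretrained_words, embedding_file, emb_layer)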
def load_embeddings(word_dict, words, embedding_file, emb_layer):
    """Load pretrained embeddings for a given list of words, if they exist.

    Args:
        word_dict: model vocabulary; used to filter and normalize tokens.
        words: iterable of tokens. Only those that are indexed in the
            dictionary are kept.
        embedding_file: path to text file of embeddings, space separated.
        emb_layer: embedding layer to initialize with the loaded vectors.
    """
    words = {w for w in words if w in word_dict}
    logger.info('Loading pre-trained embeddings for %d words from %s' %
                (len(words), embedding_file))

    # When normalized, some words are duplicated. (Average the embeddings).
    vec_counts, embedding = {}, {}
    with open(embedding_file) as f:
        # Skip first line if of form count/dim.
        line = f.readline().rstrip().split(' ')
        if len(line) != 2:
            f.seek(0)

        duplicates = set()
        for line in tqdm(f, total=count_file_lines(embedding_file)):
            parsed = line.rstrip().split(' ')
            assert len(parsed) == emb_layer.word_vec_size + 1
            w = word_dict.normalize(parsed[0])
            if w in words:
                vec = torch.Tensor([float(i) for i in parsed[1:]])
                if w not in vec_counts:
                    vec_counts[w] = 1
                    embedding[w] = vec
                else:
                    duplicates.add(w)
                    vec_counts[w] = vec_counts[w] + 1
                    embedding[w].add_(vec)

    if len(duplicates) > 0:
        logging.warning('WARN: Duplicate embedding found for %s' %
                        ', '.join(duplicates))

    for w, c in vec_counts.items():
        embedding[w].div_(c)

    emb_layer.init_word_vectors(word_dict, embedding)
    logger.info('Loaded %d embeddings (%.2f%%)' %
                (len(vec_counts), 100 * len(vec_counts) / len(words)))
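# End-to-end sketch of load_embeddings with stand-in objects. The stub
# vocabulary and embedding layer are hypothetical; they expose only the pieces
# load_embeddings touches (`__contains__`/`normalize` on the dictionary,
# `word_vec_size`/`init_word_vectors` on the layer). The toy file also shows
# the header handling (a "<count> <dim>" first line is skipped) and the
# duplicate handling (repeated tokens have their vectors averaged).
class _StubDict:
    def __init__(self, tokens):
        self._tokens = {Vocabulary.normalize(t) for t in tokens}

    def __contains__(self, w):
        return w in self._tokens

    def normalize(self, token):
        return Vocabulary.normalize(token)


class _StubEmbLayer:
    word_vec_size = 3

    def init_word_vectors(self, word_dict, embedding):
        # A real layer would copy each vector into the corresponding row of
        # its weight matrix; here we just print what would be initialized.
        for w, vec in sorted(embedding.items()):
            print(w, vec.tolist())


def _demo_load_embeddings(path='toy_vectors.txt'):
    with open(path, 'w') as f:
        f.write('3 3\n')               # count/dim header, skipped by load_embeddings
        f.write('foo 0.0 0.0 0.0\n')
        f.write('foo 1.0 1.0 1.0\n')   # duplicate: averaged with the line above
        f.write('bar 0.5 0.5 0.5\n')
    word_dict = _StubDict(['foo', 'bar'])
    # 'missing' is filtered out because it is not in the stub dictionary.
    load_embeddings(word_dict, ['foo', 'bar', 'missing'], path, _StubEmbLayer())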