Example #1
def load_data(args, filenames, max_examples=-1, dataset_name='java',
              test_split=False):
    """Load examples from preprocessed file. One example per line, JSON encoded."""

    with open(filenames['src']) as f:
        sources = [line.strip() for line in
                   tqdm(f, total=count_file_lines(filenames['src']))]

    if filenames['tgt'] is not None:
        with open(filenames['tgt']) as f:
            targets = [line.strip() for line in
                       tqdm(f, total=count_file_lines(filenames['tgt']))]
    else:
        targets = [None] * len(sources)

    if filenames['src_tag'] is not None:
        with open(filenames['src_tag']) as f:
            source_tags = [line.strip() for line in
                           tqdm(f, total=count_file_lines(filenames['src_tag']))]
    else:
        source_tags = [None] * len(sources)
        
    if args.use_tree_relative_attn:
        with open(filenames["rel_matrix"]) as f:
            rel_matrices = [json.loads(line) for line in
                            tqdm(f, total=count_file_lines(filenames["rel_matrix"]))]
    else:
        rel_matrices = [None] * len(sources)
        
    assert len(sources) == len(source_tags) == len(targets) == len(rel_matrices)

    examples = []
    for src, src_tag, tgt, rel_matrix in tqdm(zip(sources, source_tags, targets,
                                                  rel_matrices),
                                              total=len(sources)):
        if dataset_name in ['java', 'python']:
            _ex = process_examples(LANG_ID_MAP[DATA_LANG_MAP[dataset_name]],
                                   src,
                                   src_tag,
                                   tgt,
                                   rel_matrix,
                                   args.max_src_len,
                                   args.max_tgt_len,
                                   args.code_tag_type,
                                   uncase=args.uncase,
                                   test_split=test_split,
                                   split_tokens=args.sum_over_subtokens)
            if _ex is not None:
                examples.append(_ex)

        # Stop once the requested number of examples has been collected.
        if max_examples != -1 and len(examples) >= max_examples:
            break

    return examples
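
All of the examples on this page assume the usual imports (json, torch, and from tqdm import tqdm) plus a project-level count_file_lines helper whose only job is to give tqdm an accurate total. Its definition is not shown here; a minimal sketch, assuming plain newline-delimited text files, could look like this:

def count_file_lines(file_path):
    """Count newline-terminated records in a text file.
    A minimal sketch; the project's actual helper may differ.
    """
    with open(file_path, 'rb') as f:
        return sum(1 for _ in f)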
Example #2
def load_data(args,
              filenames,
              max_examples=-1,
              dataset_name='java',
              test_split=False):
    """Load examples from preprocessed file. One example per line, JSON encoded."""

    with open(filenames['src']) as f:
        sources = [
            line.strip()
            for line in tqdm(f, total=count_file_lines(filenames['src']))
        ]

    if filenames['tgt'] is not None:
        with open(filenames['tgt']) as f:
            targets = [
                line.strip()
                for line in tqdm(f, total=count_file_lines(filenames['tgt']))
            ]
    else:
        targets = [None] * len(sources)

    if filenames['src_tag'] is not None:
        with open(filenames['src_tag']) as f:
            source_tags = [
                line.strip()
                for line in tqdm(f,
                                 total=count_file_lines(filenames['src_tag']))
            ]
    else:
        source_tags = [None] * len(sources)

    assert len(sources) == len(source_tags) == len(targets)

    examples = []
    for src, src_tag, tgt in tqdm(zip(sources, source_tags, targets),
                                  total=len(sources)):
        if dataset_name in ['java', 'python']:
            _ex = process_examples(LANG_ID_MAP[DATA_LANG_MAP[dataset_name]],
                                   src,
                                   src_tag,
                                   tgt,
                                   args.max_src_len,
                                   args.max_tgt_len,
                                   args.code_tag_type,
                                   uncase=args.uncase,
                                   test_split=test_split)
            if _ex is not None:
                examples.append(_ex)

        # Stop once the requested number of examples has been collected.
        if max_examples != -1 and len(examples) >= max_examples:
            break

    return examples
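
A hypothetical call site for this variant of load_data, using an argparse-style namespace. The attribute names match what the function reads, but the values and file paths below are placeholders, not taken from the original project:

from argparse import Namespace

args = Namespace(max_src_len=400, max_tgt_len=30,
                 code_tag_type='subtoken', uncase=True)
filenames = {'src': 'data/train/code.txt',       # placeholder paths
             'tgt': 'data/train/summaries.txt',
             'src_tag': None}
train_exs = load_data(args, filenames, max_examples=1000,
                      dataset_name='java')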
Example #3
def index_embedding_words(embedding_file):
    """Put all the words in embedding_file into a set."""
    words = set()
    with open(embedding_file) as f:
        for line in tqdm(f, total=count_file_lines(embedding_file)):
            w = Vocabulary.normalize(line.rstrip().split(' ')[0])
            words.add(w)

    words.update([BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD])
    return words
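
A hypothetical usage, assuming a word2vec/GloVe-style text file in which each line starts with the token followed by its space-separated vector components (the path is a placeholder):

emb_words = index_embedding_words('embeddings/glove.840B.300d.txt')
print('indexed %d words (incl. BOS/EOS/PAD/UNK)' % len(emb_words))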
Example #4
    def load_embeddings(word_dict, words, embedding_file, emb_layer):
        """Load pretrained embeddings for a given list of words, if they exist.
        #TODO: update args
        Args:
            words: iterable of tokens. Only those that are indexed in the
              dictionary are kept.
            embedding_file: path to text file of embeddings, space separated.
        """
        words = {w for w in words if w in word_dict}
        logger.info('Loading pre-trained embeddings for %d words from %s' %
                    (len(words), embedding_file))

        # When normalized, some words are duplicated. (Average the embeddings).
        vec_counts, embedding = {}, {}
        with open(embedding_file) as f:
            # Skip first line if of form count/dim.
            line = f.readline().rstrip().split(' ')
            if len(line) != 2:
                f.seek(0)

            duplicates = set()
            for line in tqdm(f, total=count_file_lines(embedding_file)):
                parsed = line.rstrip().split(' ')
                assert (len(parsed) == emb_layer.word_vec_size + 1)
                w = word_dict.normalize(parsed[0])
                if w in words:
                    vec = torch.Tensor([float(i) for i in parsed[1:]])
                    if w not in vec_counts:
                        vec_counts[w] = 1
                        embedding[w] = vec
                    else:
                        duplicates.add(w)
                        vec_counts[w] = vec_counts[w] + 1
                        embedding[w].add_(vec)

            if len(duplicates) > 0:
                logger.warning('Duplicate embedding found for %s' %
                               ', '.join(duplicates))

        for w, c in vec_counts.items():
            embedding[w].div_(c)

        emb_layer.init_word_vectors(word_dict, embedding)
        logger.info('Loaded %d embeddings (%.2f%%)' %
                    (len(vec_counts), 100 * len(vec_counts) / len(words)))
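
The duplicate handling above accumulates a running sum with add_ and divides by the occurrence count afterwards, i.e. it averages all embeddings whose tokens normalize to the same word. A standalone sketch of just that arithmetic, with made-up data:

import torch

vec_counts, embedding = {}, {}
pairs = [('cat', torch.tensor([1.0, 3.0])),   # two rows normalize to 'cat'
         ('cat', torch.tensor([3.0, 5.0]))]
for w, vec in pairs:
    if w not in vec_counts:
        vec_counts[w] = 1
        embedding[w] = vec
    else:
        vec_counts[w] += 1
        embedding[w].add_(vec)   # in-place running sum of duplicates

for w, c in vec_counts.items():
    embedding[w].div_(c)         # sum / count == mean

print(embedding['cat'])          # tensor([2., 4.])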