# Example No. 1
# 0
def load_data(config, dirname='../dataset/', max_sample_size=None):
    """Load the gendered-names dataset and a 'neutral' pretraining word set.

    Names are read from every file under ``<dirname>/boy`` and
    ``<dirname>/girl`` and labelled with that directory name. For each line,
    the name is the second comma-separated field, with whitespace and
    punctuation removed. Words from the first
    ``config.HPCONFIG.lm_samples_count`` lines of
    ``config.HPCONFIG.lm_dataset_path`` are labelled ``'neutral'``.

    Both sets feed two vocabularies: the input vocabulary (over the letters
    returned by ``tamil.utf8.get_letters``) and the gender vocabulary. Each
    entry is then expanded into ``Sample`` objects whose template reveals
    two letters and fills the rest with ``NULL_CHAR``.

    Args:
        config: project configuration. Reads ``HPCONFIG.lm_dataset_path``,
            ``HPCONFIG.lm_samples_count``, ``CONFIG.write_vocab_to_file``,
            ``CONFIG.split_ratio`` and ``ROOT_DIR``.
        dirname: directory holding the ``boy`` and ``girl`` sub-directories.
        max_sample_size: if truthy, stop expanding once more than this many
            samples exist. Applied separately to the names and to the
            pretraining words.

    Returns:
        ``NameDataset('names', (train, test), ...)``. The name samples are
        split at ``CONFIG.split_ratio``. The pretraining samples and both
        vocabularies are passed as keyword arguments.
    """
    # Number of entries that raised while being turned into samples;
    # incremented by build_samples() through ``nonlocal``.
    skipped = 0

    input_vocab = Counter()
    gender_vocab = Counter()

    #########################################################
    # Read names
    #########################################################
    def read_data(filename='names.csv'):
        """Return the cleaned name (2nd comma-separated field) of each line."""
        with open(filename, encoding='utf-8') as f:
            lines = f.readlines()
        names = []
        for line in lines:
            name = line.split(',')[1]
            name = ''.join(name.split())  # drop every whitespace character
            names.append(remove_punct_symbols(name))
        return names

    def read_dirs(dirs=('boy', 'girl')):
        """Return deduplicated (label, name) pairs; label = sub-directory."""
        pairs = []
        for d in dirs:
            for fname in os.listdir('{}/{}'.format(dirname, d)):
                names = read_data('{}/{}/{}'.format(dirname, d, fname))
                pairs.extend((d, n) for n in names)
        # NOTE(review): set() iteration order for str tuples changes between
        # interpreter runs under hash randomization, so the train/test split
        # taken later from this order is not reproducible across runs.
        return list(set(pairs))

    raw_samples = read_dirs()
    log.info('read {} names'.format(len(raw_samples)))

    #########################################################
    # Read tamil words
    #########################################################
    def read_words(filename=config.HPCONFIG.lm_dataset_path):
        """Return deduplicated ('neutral', word) pairs from the LM corpus."""
        with open(filename, encoding='utf-8') as f:
            lines = f.readlines()[:config.HPCONFIG.lm_samples_count]
        pairs = []
        for line in tqdm(lines, desc='reading lm file for words'):
            pairs.extend(('neutral', w) for w in line.split())
        return list(set(pairs))

    pretrain_samples = read_words()

    #########################################################
    # build vocab
    #########################################################
    all_samples = raw_samples + pretrain_samples
    log.info('building input_vocabulary...')

    for gender, name in tqdm(all_samples, desc='building vocab'):
        name = remove_punct_symbols(name)
        name = tamil.utf8.get_letters(name.strip())
        if len(name):
            input_vocab.update(name)
            gender_vocab.update([gender])

    vocab = Vocab(input_vocab, special_tokens=VOCAB, freq_threshold=50)

    print(gender_vocab)
    gender_vocab = Vocab(gender_vocab, special_tokens=[])

    if config.CONFIG.write_vocab_to_file:
        vocab.write_to_file(config.ROOT_DIR + '/input_vocab.csv')
        gender_vocab.write_to_file(config.ROOT_DIR + '/gender_vocab.csv')

    def build_samples(raw_samples):
        """Expand (gender, name) pairs into masked-template ``Sample`` objects.

        Entries that raise are logged and counted in the enclosing
        ``skipped`` counter instead of aborting the whole build.
        """
        # Without this, ``skipped += 1`` below would make ``skipped`` local to
        # this function and the except handler itself would raise
        # UnboundLocalError.
        nonlocal skipped
        samples = []
        for i, (gender, name) in enumerate(
                tqdm(raw_samples, desc='processing names')):
            try:
                letters = tamil.utf8.get_letters(name.strip())

                if len(letters) < 2:
                    continue

                log.debug('===')
                log.debug(pformat(letters))

                # NOTE(review): zip stops at the shorter range, so this yields
                # the adjacent pairs (0, 1) .. (n-3, n-2). The final pair
                # (n-2, n-1) is never produced, and two-letter names yield no
                # samples at all. Confirm whether the second range should end
                # at len(letters).
                for a, b in zip(range(len(letters)),
                                range(1, len(letters) - 1)):
                    template = list(NULL_CHAR * len(letters))
                    template[a] = letters[a]
                    template[b] = letters[b]
                    samples.append(
                        Sample('{}.{}'.format(gender, i), gender, template,
                               letters))

                if max_sample_size and len(samples) > max_sample_size:
                    break

            except Exception:
                skipped += 1
                log.exception('{}'.format(name))

        return samples

    pretrain_samples = build_samples(pretrain_samples)
    samples = build_samples(raw_samples)
    print('skipped {} samples'.format(skipped))

    pivot = int(len(samples) * config.CONFIG.split_ratio)
    train_samples, test_samples = samples[:pivot], samples[pivot:]

    return NameDataset('names', (train_samples, test_samples),
                       pretrain_samples=pretrain_samples,
                       input_vocab=vocab,
                       gender_vocab=gender_vocab)
# Example No. 2
# 0
def load_data(config,
              filename='../dataset/lm_lengthsorted.txt',
              max_sample_size=None):
    """Build a (center word, context word) pair dataset from a text corpus.

    The first ``config.HPCONFIG.max_samples`` lines of *filename* are
    processed as follows:

    - Each line is passed through ``remove_punct_symbols`` and split on
      whitespace; identical sentences are deduplicated.
    - The words build a ``Vocab`` whose frequency threshold is the configured
      ``freq_threshold`` scaled by ``max_samples / line count``.
    - Sentences with fewer than two words, or whose UNK ratio exceeds 0.7,
      are dropped.
    - For every other sentence, each pair of non-UNK words at most
      ``config.HPCONFIG.window_size`` positions apart becomes a ``Sample``.
      Only the first occurrence of each distinct pair produces a sample.

    Args:
        config: project configuration. Reads ``HPCONFIG.max_samples``,
            ``HPCONFIG.freq_threshold``, ``HPCONFIG.window_size``,
            ``CONFIG.write_vocab_to_file``, ``CONFIG.dump_bloom_filter`` and
            ``ROOT_DIR``.
        filename: path of the UTF-8 text corpus, one sentence per line.
        max_sample_size: if truthy, stop once more than this many samples
            have been produced (checked after each sentence).

    Returns:
        ``Dataset(filename, (samples, []), input_vocab=vocab,
        output_vocab=vocab)``. All samples go to training, and the same
        vocabulary serves as input and output vocabulary.

    Raises:
        OSError: if *filename* cannot be opened.
        ValueError: if *filename* is empty.
    """
    samples = []
    skipped = 0

    input_vocab = Counter()
    # Occurrence count of every (center, context) pair. Despite the name this
    # is a plain Counter, not a probabilistic Bloom filter. It makes each
    # distinct pair produce a Sample only once, and can be dumped below.
    bloom_filter = Counter()

    log.info('processing file: {}'.format(filename))
    with open(filename, encoding='utf-8') as f:
        text_file = f.readlines()
    if not text_file:
        # The frequency-threshold scaling below divides by the line count.
        raise ValueError('{} is empty'.format(filename))

    log.info('building input_vocabulary...')
    sentences = set()
    for line in tqdm(text_file[:config.HPCONFIG.max_samples],
                     desc='processing {}'.format(filename)):
        sentence = remove_punct_symbols(line).strip().split()
        if sentence:
            input_vocab.update(sentence)
            sentences.add(tuple(sentence))

    # Scale the configured threshold by the fraction of the file requested.
    # NOTE(review): when max_samples exceeds the line count this pushes the
    # threshold above the configured value; confirm that is intended.
    freq_threshold = (config.HPCONFIG.freq_threshold
                      * (float(config.HPCONFIG.max_samples) / len(text_file)))
    log.info('freq_threhold: {}'.format(freq_threshold))
    vocab = Vocab(input_vocab,
                  special_tokens=VOCAB,
                  freq_threshold=int(freq_threshold))

    if config.CONFIG.write_vocab_to_file:
        vocab.write_to_file(config.ROOT_DIR + '/vocab.csv')

    window_size = config.HPCONFIG.window_size
    for i, sentence in tqdm(enumerate(sentences),
                            desc='processing sentences'):
        if len(sentence) < 2:
            continue

        # Guard each sentence separately so a failure is counted in
        # ``skipped`` rather than aborting the whole load.
        try:
            unk_ratio = float(count_UNKS(sentence, vocab)) / len(sentence)

            log.debug('===')
            log.debug(pformat(sentence))

            # Replace out-of-vocabulary words with the literal 'UNK' token.
            sentence = [word if vocab[word] != vocab['UNK'] else 'UNK'
                        for word in sentence]
            log.debug(pformat(sentence))

            if unk_ratio > 0.7:
                log.debug('unk ratio is heavy: {}'.format(unk_ratio))
                continue

            for center_pos, center_word in enumerate(sentence):
                if center_word == 'UNK':
                    continue
                # Context positions inside the window, clipped to the
                # sentence bounds; visited in ascending order.
                first = max(0, center_pos - window_size)
                last = min(len(sentence), center_pos + window_size + 1)
                for context_pos in range(first, last):
                    context_word = sentence[context_pos]
                    if context_pos == center_pos or context_word == 'UNK':
                        continue
                    pair = (center_word, context_word)
                    if pair not in bloom_filter:
                        samples.append(
                            Sample('{}.{}'.format(i, center_pos),
                                   center_word,
                                   context_word))
                    bloom_filter[pair] += 1

        except Exception:
            skipped += 1
            log.exception('{}'.format(sentence))

        if max_sample_size and len(samples) > max_sample_size:
            break

    print('skipped {} samples'.format(skipped))

    if config.CONFIG.dump_bloom_filter:
        # One "center|context|count" line per distinct pair.
        with open('word_pair.csv', 'w', encoding='utf-8') as out:
            for pair, count in bloom_filter.items():
                out.write('|'.join(list(pair) + [str(count)]) + '\n')

    # No held-out split: every sample is used for training.
    train_samples, test_samples = samples, []

    return Dataset(filename,
                   (train_samples, test_samples),
                   input_vocab=vocab,
                   output_vocab=vocab)