Example #1
# numpy is required below; the project's `data_producer` module and the
# MAX_SENT_LEN constant are assumed to be defined at module level.
import numpy as np
def train_process_sent_producer(p_id, data_queue, word_count_actual, word2idx,
                                word_list, freq, args):
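    """Producer worker: read this process's slice of the training file, turn
    each sentence into CBOW / skip-gram training rows via `data_producer`, and
    push full batches (numpy int64 arrays) onto `data_queue`; a trailing `None`
    marks the end of the stream."""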
    if args.negative > 0:
        table_ptr_val = data_producer.init_unigram_table(
            word_list, freq, args.train_words)

    train_file = open(args.train)
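    # Seek to this worker's byte offset; if it lands inside a multi-byte UTF-8
    # character, back up one byte at a time until a read succeeds.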
    file_pos = args.file_size * p_id // args.processes
    train_file.seek(file_pos, 0)
    while True:
        try:
            train_file.read(1)
        except UnicodeDecodeError:
            file_pos -= 1
            train_file.seek(file_pos, 0)
        else:
            train_file.seek(file_pos, 0)
            break

    batch_count = 0
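    # Pre-allocate one batch of rows; the row width must match what
    # data_producer.cbow_producer / sg_producer emit per training example.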
    if args.cbow == 1:
        batch_placeholder = np.zeros(
            (args.batch_size, 2 * args.window + 2 + 2 * args.negative),
            'int64')
    else:
        batch_placeholder = np.zeros((args.batch_size, 2 + 2 * args.negative),
                                     'int64')

    for it in range(args.iter):
        train_file.seek(file_pos, 0)

        last_word_cnt = 0
        word_cnt = 0
        sentence = []
        prev = ''
        eof = False
        while True:
            if eof or train_file.tell() > file_pos + args.file_size / args.processes:
                break

            while True:
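                # Read the file character by character: space/tab ends a token,
                # newline (or MAX_SENT_LEN tokens) ends the sentence, and only
                # in-vocabulary words are kept.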
                s = train_file.read(1)
                if not s:
                    eof = True
                    break
                elif s == ' ' or s == '\t':
                    if prev in word2idx:
                        sentence.append(prev)
                    prev = ''
                    if len(sentence) >= MAX_SENT_LEN:
                        break
                elif s == '\n':
                    if prev in word2idx:
                        sentence.append(prev)
                    prev = ''
                    break
                else:
                    prev += s

            if len(sentence) > 0:
                # subsampling
                sent_id = []
                if args.sample != 0:
                    sent_len = len(sentence)
                    i = 0
                    while i < sent_len:
                        word = sentence[i]
                        f = freq[word] / args.train_words
                        pb = (np.sqrt(f / args.sample) + 1) * args.sample / f

                        if pb > np.random.random_sample():
                            sent_id.append(word2idx[word])
                        i += 1
                else:
                    # no subsampling: keep every in-vocabulary word
                    sent_id = [word2idx[word] for word in sentence]

                if len(sent_id) < 2:
                    word_cnt += len(sentence)
                    sentence.clear()
                    continue

                next_random = (2**24) * np.random.randint(
                    0, 2**24) + np.random.randint(0, 2**24)
                if args.cbow == 1:  # train CBOW
                    chunk = data_producer.cbow_producer(
                        sent_id, len(sent_id), table_ptr_val, args.window,
                        args.negative, args.vocab_size, args.batch_size,
                        next_random)
                elif args.cbow == 0:  # train skipgram
                    chunk = data_producer.sg_producer(
                        sent_id, len(sent_id), table_ptr_val, args.window,
                        args.negative, args.vocab_size, args.batch_size,
                        next_random)

                chunk_pos = 0
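                # Copy chunk rows into the batch buffer, emitting a batch on
                # the queue whenever the buffer fills up.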
                while chunk_pos < chunk.shape[0]:
                    remain_space = args.batch_size - batch_count
                    remain_chunk = chunk.shape[0] - chunk_pos

                    if remain_chunk < remain_space:
                        take_from_chunk = remain_chunk
                    else:
                        take_from_chunk = remain_space

                    batch_placeholder[batch_count:batch_count +
                                      take_from_chunk, :] = chunk[
                                          chunk_pos:chunk_pos +
                                          take_from_chunk, :]
                    batch_count += take_from_chunk

                    if batch_count == args.batch_size:
                        data_queue.put(batch_placeholder)
                        batch_count = 0

                    chunk_pos += take_from_chunk

                word_cnt += len(sentence)
                if word_cnt - last_word_cnt > 10000:
                    with word_count_actual.get_lock():
                        word_count_actual.value += word_cnt - last_word_cnt
                    last_word_cnt = word_cnt
                sentence.clear()

        with word_count_actual.get_lock():
            word_count_actual.value += word_cnt - last_word_cnt

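    # Flush any partially filled batch, then signal end-of-stream with None.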
    if batch_count > 0:
        data_queue.put(batch_placeholder[:batch_count, :])
    data_queue.put(None)

Example #2
import numpy as np
import torch  # torch.randperm is used when args.shuffle is set

# As in Example #1, the project's `data_producer` module and the MAX_SENT_LEN
# constant are assumed to be defined at module level.
def train_process_sent_producer(p_id, data_queue, word_count_actual, word2idx,
                                word_list, freq, args):
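    """Variant of the producer above: rows are accumulated into a larger buffer
    of `args.megabatch_size`, optionally shuffled (`args.shuffle`), and then
    split into `args.batch_size` slices before being placed on `data_queue`."""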
    if args.negative > 0:
        table_ptr_val = data_producer.init_unigram_table(
            word_list, freq, args.train_words)

    train_file = open(args.train)
    file_pos = args.file_size * p_id // args.processes
    train_file.seek(file_pos, 0)
    while True:
        try:
            train_file.read(1)
        except UnicodeDecodeError:
            file_pos -= 1
            train_file.seek(file_pos, 0)
        else:
            train_file.seek(file_pos, 0)
            break

    batch_count = 0
    if args.cbow == 1:
        batch_placeholder = np.zeros(
            (args.megabatch_size, 2 * args.window + 2 + 2 * args.negative),
            'int64')
    else:
        batch_placeholder = np.zeros(
            (args.megabatch_size, 2 + 2 * args.negative), 'int64')
    for it in range(args.iter):
        train_file.seek(file_pos, 0)

        last_word_cnt = 0
        word_cnt = 0
        sentence = []
        prev = ''
        eof = False
        while True:
            if eof or train_file.tell() > file_pos + args.file_size / args.processes:
                break

            while True:
                s = train_file.read(1)
                if not s:
                    eof = True
                    break
                elif s == ' ' or s == '\t':
                    if prev in word2idx:
                        sentence.append(prev)
                    prev = ''
                    if len(sentence) >= MAX_SENT_LEN:
                        break
                elif s == '\n':
                    if prev in word2idx:
                        sentence.append(prev)
                    prev = ''
                    break
                else:
                    prev += s

            if len(sentence) > 0:
                #print("Full sentence")
                #print(' '.join(sentence))
                # subsampling
                sent_id = []
                if args.sample != 0:
                    sent_len = len(sentence)
                    i = 0
                    while i < sent_len:
                        word = sentence[i]
                        f = freq[word] / args.train_words
                        pb = (np.sqrt(f / args.sample) + 1) * args.sample / f

                        if pb > np.random.random_sample():
                            sent_id.append(word2idx[word])
                        i += 1
                else:
                    # no subsampling: keep every in-vocabulary word
                    sent_id = [word2idx[word] for word in sentence]

                if len(sent_id) < 2:
                    word_cnt += len(sentence)
                    sentence.clear()
                    continue

                #print("Killed words")
                #print(' '.join(trimmed))
                #print("Trimmed sentence")
                #print(' '.join([word_list[index] for index in sent_id]))

                next_random = (2**24) * np.random.randint(
                    0, 2**24) + np.random.randint(0, 2**24)
                if args.cbow == 1:  # train CBOW
                    chunk = data_producer.cbow_producer(
                        sent_id, len(sent_id), table_ptr_val, args.window,
                        args.negative, args.vocab_size, args.batch_size,
                        next_random)
                elif args.cbow == 0:  # train skipgram
                    chunk = data_producer.sg_producer(
                        sent_id, len(sent_id), table_ptr_val, args.window,
                        args.negative, args.vocab_size, args.batch_size,
                        next_random)

                #print("Data points")
                #print(chunk)

                chunk_pos = 0
                while chunk_pos < chunk.shape[0]:
                    remain_space = args.megabatch_size - batch_count
                    remain_chunk = chunk.shape[0] - chunk_pos

                    if remain_chunk < remain_space:
                        take_from_chunk = remain_chunk
                    else:
                        take_from_chunk = remain_space

                    batch_placeholder[batch_count:batch_count +
                                      take_from_chunk, :] = chunk[
                                          chunk_pos:chunk_pos +
                                          take_from_chunk, :]
                    batch_count += take_from_chunk

                    if batch_count == args.megabatch_size:
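                        # Optionally shuffle the full megabatch, then slice it
                        # into batch_size pieces for the consumer.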
                        if args.shuffle:
                            p = torch.randperm(batch_count)
                            batch_placeholder = batch_placeholder[p]

                        start = 0
                        while start < batch_count:
                            data_queue.put(batch_placeholder[
                                start:min(start +
                                          args.batch_size, batch_count)])
                            start += args.batch_size
                        #print("Batch placeholder")
                        #print(batch_placeholder)
                        batch_count = 0

                    chunk_pos += take_from_chunk

                word_cnt += len(sentence)
                if word_cnt - last_word_cnt > 10000:
                    with word_count_actual.get_lock():
                        word_count_actual.value += word_cnt - last_word_cnt
                    last_word_cnt = word_cnt
                sentence.clear()

        with word_count_actual.get_lock():
            word_count_actual.value += word_cnt - last_word_cnt

    #print("Total occurrences of mattrum: " + str(mattrum_cnt))
    #print("Total non-occurrences of mattrum: " + str(non_mattrum_cnt))
    if batch_count > 0:
        if args.shuffle:
            p = torch.randperm(batch_count)
            batch_placeholder[:batch_count] = batch_placeholder[p]

        start = 0
        while start < batch_count:
            data_queue.put(
                batch_placeholder[start:min(start +
                                            args.batch_size, batch_count)])
            start += args.batch_size
        #print("Batch placeholder")
        #print(batch_placeholder)
        batch_count = 0
    data_queue.put(None)
Example #3
    args = parser.parse_args()
    print("Starting training using file %s" % args.train)

    train_file = open(args.train)
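    # Seek to the end to measure the corpus size in bytes; file_size is stored
    # on args so the producer processes can split the file among themselves.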
    train_file.seek(0, 2)
    vars(args)['file_size'] = train_file.tell()

    word2idx, word_list, freq = build_vocab(args)
    model = init_net(args)

    if args.cuda:
        model.cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr)
    if args.negative > 0:
        table_ptr_val = data_producer.init_unigram_table(
            word_list, freq, args.train_words)

    vars(args)['t_start'] = time.monotonic()

    train_file = open(args.train)
    train_file.seek(0, 0)

    word_count_actual = 0
    for it in range(args.iter):
        #print("iter: %d" % it)
        train_file.seek(0, 0)

        batch_count = 0
        batch_placeholder = np.zeros(
            (args.batch_size, 2 * args.window + 2 + 2 * args.negative),
            'int64')