Code Example #1
File: augmenter.py Project: nusnlp/corelm
def augment(model_path, input_nbest_path, vocab_path, output_nbest_path):
	classifier = MLP(model_path=model_path)
	evaluator = eval.Evaluator(None, classifier)

	vocab = VocabManager(vocab_path)

	ngram_size = classifier.ngram_size

	def get_ngrams(tokens):
		for i in range(ngram_size - 1):
			tokens.insert(0, '<s>')
		if vocab.has_end_padding:
			tokens.append('</s>')
		indices = vocab.get_ids_given_word_list(tokens)
		return U.get_all_windows(indices, ngram_size)
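	# Illustrative note (assuming ngram_size = 3 and end padding enabled):
	# tokens ['a', 'b'] are padded to ['<s>', '<s>', 'a', 'b', '</s>'] and
	# produce the id windows for (<s> <s> a), (<s> a b), (a b </s>).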

	input_nbest = NBestList(input_nbest_path, mode='r')
	output_nbest = NBestList(output_nbest_path, mode='w')

	L.info('Augmenting: ' + input_nbest_path)
	
	start_time = time.time()

	counter = 0
	cache = dict()
	for group in input_nbest:
		ngram_list = []
		for item in group:
			tokens = item.hyp.split()
			ngrams = get_ngrams(tokens)
			for ngram in ngrams:
				if str(ngram) not in cache:
					ngram_list.append(ngram)
					cache[str(ngram)] = 1000  # placeholder; replaced by the real log-probability below
		if len(ngram_list) > 0:
			ngram_array = np.asarray(ngram_list, dtype='int32')
			ngram_log_prob_list = evaluator.get_ngram_log_prob(ngram_array[:,0:-1], ngram_array[:,-1])
			for i in range(len(ngram_list)):
				cache[str(ngram_list[i])] = ngram_log_prob_list[i]
		for item in group:
			tokens = item.hyp.split()
			ngrams = get_ngrams(tokens)
			sum_ngram_log_prob = 0
			for ngram in ngrams:
				sum_ngram_log_prob += cache[str(ngram)]
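			# the summed n-gram log-probabilities give the sentence-level LM score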
			item.append_feature(sum_ngram_log_prob)
			output_nbest.write(item)
		#print counter
		counter += 1
	output_nbest.close()

	L.info("Ran for %.2fs" % (time.time() - start_time))
Code Example #2
File: textReader.py Project: wanghm92/corelm_sll
	def __init__(self, dataset_path, is_nbest, ngram_size, vocab_path):
		
		L.info("Initializing dataset from: " + dataset_path)
		
		vocab = VocabManager(vocab_path)
		
		def get_ngrams(tokens):
			for i in range(ngram_size - 1):
				tokens.insert(0, '<s>')
			if vocab.has_end_padding:
				tokens.append('</s>')
			indices = vocab.get_ids_given_word_list(tokens)
			return U.get_all_windows(indices, ngram_size)
		
		starts_list = []
		curr_index = 0
		curr_start_index = 0
		self.num_sentences = 0
		
		ngrams_list = []
		if is_nbest:
			nbest = NBestList(dataset_path)
			for group in nbest:
				for item in group:
					tokens = item.hyp.split()
					starts_list.append(curr_start_index)
					ngrams = get_ngrams(tokens)
					ngrams_list += ngrams
					curr_start_index += len(ngrams)
		else:
			dataset = codecs.open(dataset_path, 'r', encoding="UTF-8")
			for line in dataset:
				tokens = line.split()
				starts_list.append(curr_start_index)
				ngrams = get_ngrams(tokens)
				ngrams_list += ngrams
				curr_start_index += len(ngrams)
			dataset.close()
		
		self.num_sentences = len(starts_list)
		
		data = np.asarray(ngrams_list)
		starts_list.append(curr_start_index)
		starts_array = np.asarray(starts_list)
		
		x = data[:,0:-1]
		y = data[:,-1]
		
		self.num_samples = y.shape[0]
		
		self.shared_starts = T.cast(theano.shared(starts_array, borrow=True), 'int64')
		self.shared_x = T.cast(theano.shared(x, borrow=True), 'int32')
		self.shared_y = T.cast(theano.shared(y, borrow=True), 'int32')
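In the layout built above, consecutive entries of starts_array delimit the n-grams of one sentence. A minimal sketch of reading one sentence back out (s is a hypothetical index, not part of the snippet):

s = 0  # any sentence index < self.num_sentences
x_s = x[starts_array[s]:starts_array[s + 1]]  # context windows of sentence s
y_s = y[starts_array[s]:starts_array[s + 1]]  # corresponding target word ids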
Code Example #3
if args.quiet:
    L.quiet = True

methods = {
    'none': B.no_smoothing,
    'epsilon': B.add_epsilon_smoothing,
    'lin': B.lin_smoothing,
    'nist': B.nist_smoothing,
    'chen': B.chen_smoothing
}

ref_path_list = args.ref_paths.split(',')

input_nbest = NBestList(args.input_path,
                        mode='r',
                        reference_list=ref_path_list)
if args.out_nbest_path:
    output_nbest = NBestList(args.out_nbest_path, mode='w')
if args.out_scores_path:
    output_scores = open(args.out_scores_path, mode='w')
output_1best = codecs.open(args.out_1best_path, mode='w', encoding='UTF-8')

U.xassert(args.method in methods,
          "Invalid smoothing method: " + args.method)
scorer = methods[args.method]

L.info('Processing the n-best list')


def process_group(group):
Code Example #4
File: rerank.py Project: wanghm92/corelm_sll
if args.no_aug:
    shutil.copy(args.input_nbest, output_nbest_path)
else:
    augmenter.augment(args.model_path, args.input_nbest, args.vocab_path,
                      output_nbest_path)

with open(args.weights, 'r') as input_weights:
    lines = input_weights.readlines()
    if len(lines) > 1:
        L.warning(
            "Weights file has more than one line. I'll read the 1st and ignore the rest."
        )
    weights = np.asarray(lines[0].strip().split(" "), dtype=float)

prefix = os.path.basename(args.input_nbest)
input_aug_nbest = NBestList(output_nbest_path, mode='r')
output_nbest = NBestList(args.out_dir + '/' + prefix + '.reranked.nbest',
                         mode='w')
output_1best = codecs.open(args.out_dir + '/' + prefix + '.reranked.1best',
                           mode='w',
                           encoding='UTF-8')


def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False
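
The reranking itself is not shown in this excerpt; a minimal sketch of how such a weight vector is typically applied to one hypothesis (feature_scores is hypothetical, not from the project):

feature_scores = np.asarray([1.2, -0.5, 3.3])   # per-feature values of one n-best item
rerank_score = np.dot(weights, feature_scores)  # linear model: weighted sum of features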

Code Example #5
File: oracle.py Project: tamhd/corelm
args = parser.parse_args()

if args.quiet:
	L.quiet = True

methods = {
	'none'    : B.no_smoothing,
	'epsilon' : B.add_epsilon_smoothing,
	'lin'     : B.lin_smoothing,
	'nist'    : B.nist_smoothing,
	'chen'    : B.chen_smoothing
}

ref_path_list = args.ref_paths.split(',')

input_nbest = NBestList(args.input_path, mode='r', reference_list=ref_path_list)
if args.out_nbest_path:
	output_nbest = NBestList(args.out_nbest_path, mode='w')
if args.out_scores_path:
	output_scores = open(args.out_scores_path, mode='w')
output_1best = codecs.open(args.out_1best_path, mode='w', encoding='UTF-8')

U.xassert(args.method in methods, "Invalid smoothing method: " + args.method)
scorer = methods[args.method]

L.info('Processing the n-best list')

def process_group(group):
	index = 0
	scores = dict()
	for item in group:
Code Example #6
parser.add_argument("-v",
                    "--vocab-file",
                    dest="vocab_path",
                    help="The vocabulary file.")
parser.add_argument("-m",
                    "--model-file",
                    dest="model_path",
                    help="Input CoreLM model file")
parser.add_argument("-d",
                    "--device",
                    dest="device",
                    default="gpu",
                    help="The computing device (cpu or gpu)")
args = parser.parse_args()

input_nbest = NBestList(args.input_path, mode='r')

mode = -1
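# -1 marks an unrecognized subcommand; each branch below sets its own mode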

if args.command.startswith('top'):
    mode = 0
    N = int(args.command[3:])  # N in N-best
    output_nbest = NBestList(args.output_path, mode='w')
elif args.command == '1best':
    mode = 1
    output_1best = codecs.open(args.output_path, mode='w', encoding='UTF-8')
elif args.command.startswith('feature'):
    mode = 2
    N = int(args.command[7:])  # Nth feature
    output = open(args.output_path, mode='w')
elif args.command.startswith('correl'):
Code Example #7
File: tools.py Project: tamhd/corelm
parser.add_argument("-i", "--input-file", dest="input_path", required=True, help="Input n-best file")
parser.add_argument("-s", "--input-scores", dest="oracle", help="Input oracle scores  the n-best file")
parser.add_argument("-o", "--output-file", dest="output_path", required=True, help="Output file")
parser.add_argument("-v", "--vocab-file", dest="vocab_path", help="The vocabulary file.")
parser.add_argument("-m", "--model-file", dest="model_path",  help="Input PrimeLM model file")
parser.add_argument("-d", "--device", dest="device", default="gpu", help="The computing device (cpu or gpu)")
args = parser.parse_args()

input_nbest = NBestList(args.input_path, mode='r')

mode = -1

if args.command.startswith('top'):
	mode = 0
	N = int(args.command[3:]) # N in N-best
	output_nbest = NBestList(args.output_path, mode='w')
elif args.command == '1best':
	mode = 1
	output_1best = codecs.open(args.output_path, mode='w', encoding='UTF-8')
elif args.command.startswith('feature'):
	mode = 2
	N = int(args.command[7:]) # Nth feature
	output = open(args.output_path, mode='w')
elif args.command.startswith('correl'):
	mode = 3
	N = int(args.command[6:]) # Nth feature
	U.xassert(args.oracle, "correlN command needs a file (-s) containing oracle scores")
	with open(args.oracle, mode='r') as oracles_file:
		oracles = list(map(float, oracles_file.read().splitlines()))
	#output = open(args.output_path, mode='w')
elif args.command.startswith('augment'):