def test_score_sequence(self):
    """Check TextScorer.score_sequence() under the three <unk> handling modes.

    The expected values assume the dummy network gives word ID ``i`` the
    probability ``i / 5``; the first word of a sequence is not scored, so
    the expected total is the sum of ``log(i / 5)`` over the scored words.
    """
    def expected_total(scored_ids):
        # Sum of log(i / 5) over the given word IDs, computed in float32.
        probabilities = scored_ids.astype('float32')
        probabilities = probabilities / 5
        return numpy.log(probabilities).sum()

    # Case 1: default scorer -- the network predicts the <unk> probability.
    default_scorer = TextScorer(self.dummy_network)
    sequence = numpy.arange(6)
    classes = numpy.arange(6)
    memberships = numpy.ones(6, dtype='float32')
    total = default_scorer.score_sequence(sequence, classes, memberships)
    self.assertAlmostEqual(total, expected_total(sequence[1:]), places=5)

    # Case 2: ignore_unk=True -- the <unk> at position 3 is left out.
    skipping_scorer = TextScorer(self.dummy_network, ignore_unk=True)
    sequence = numpy.arange(6)
    sequence[3] = self.vocabulary.word_to_id['<unk>']
    classes = numpy.arange(6)
    memberships = numpy.ones(6, dtype='float32')
    total = skipping_scorer.score_sequence(sequence, classes, memberships)
    self.assertAlmostEqual(total, expected_total(sequence[[1, 2, 4, 5]]),
                           places=5)

    # Case 3: unk_penalty=-5 -- the <unk> contributes a constant -5.
    penalizing_scorer = TextScorer(self.dummy_network, ignore_unk=False,
                                   unk_penalty=-5)
    sequence = numpy.arange(6)
    sequence[3] = self.vocabulary.word_to_id['<unk>']
    classes = numpy.arange(6)
    memberships = numpy.ones(6, dtype='float32')
    total = penalizing_scorer.score_sequence(sequence, classes, memberships)
    self.assertAlmostEqual(total, expected_total(sequence[[1, 2, 4, 5]]) - 5,
                           places=5)
def test_score_batch(self):
    """Check TextScorer.score_batch() under the three <unk> handling modes.

    The batch is a 3x2 matrix of word IDs where each column is one sequence.
    The expected values assume the dummy network gives word ID ``i`` the
    probability ``i / 5``; the first time step is not scored.
    """
    def batch_inputs(unk_position=None):
        # Build (word_ids, class_ids, membership_probs, mask), optionally
        # replacing the word at unk_position with the <unk> ID.
        ids = numpy.arange(6).reshape((3, 2))
        if unk_position is not None:
            ids[unk_position] = self.vocabulary.word_to_id['<unk>']
        classes = numpy.arange(6).reshape((3, 2))
        memberships = numpy.ones_like(ids).astype('float32')
        ones_mask = numpy.ones_like(ids)
        return ids, classes, memberships, ones_mask

    def expected_logprobs(selected_ids):
        # Element-wise log(i / 5), computed in float32.
        return numpy.log(selected_ids.astype('float32') / 5)

    # Case 1: default scorer -- the network predicts the <unk> probability.
    default_scorer = TextScorer(self.dummy_network)
    ids, classes, memberships, ones_mask = batch_inputs()
    scores = default_scorer.score_batch(ids, classes, memberships, ones_mask)
    assert_almost_equal(scores[0], expected_logprobs(ids[1:, 0]))
    assert_almost_equal(scores[1], expected_logprobs(ids[1:, 1]))

    # Case 2: ignore_unk=True -- the <unk> in the second sequence is dropped.
    skipping_scorer = TextScorer(self.dummy_network, ignore_unk=True)
    ids, classes, memberships, ones_mask = batch_inputs((1, 1))
    scores = skipping_scorer.score_batch(ids, classes, memberships, ones_mask)
    assert_almost_equal(scores[0], expected_logprobs(ids[1:, 0]))
    assert_almost_equal(scores[1], expected_logprobs(ids[2:, 1]))

    # Case 3: unk_penalty=-5 -- the <unk> is scored as a constant -5.
    penalizing_scorer = TextScorer(self.dummy_network, ignore_unk=False,
                                   unk_penalty=-5)
    ids, classes, memberships, ones_mask = batch_inputs((1, 1))
    scores = penalizing_scorer.score_batch(ids, classes, memberships,
                                           ones_mask)
    assert_almost_equal(scores[0], expected_logprobs(ids[1:, 0]))
    assert_almost_equal(scores[1][0], -5)
    assert_almost_equal(scores[1][1], expected_logprobs(ids[2, 1]))
def _configure_logging(args):
    """Configure the root logger from ``args.log_level`` and ``args.log_file``.

    A ``log_file`` of ``'-'`` sends log output to standard output.

    Raises:
        ValueError: if ``args.log_level`` does not name a logging level.
    """
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        raise ValueError("Invalid logging level requested: " + args.log_level)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        logging.basicConfig(stream=sys.stdout, format=log_format,
                            level=log_level)
    else:
        logging.basicConfig(filename=args.log_file, format=log_format,
                            level=log_level)


def _configure_theano(args):
    """Set Theano test-value computation and profiling flags from ``args``."""
    theano.config.compute_test_value = 'warn' if args.debug else 'off'
    theano.config.profile = args.profile
    theano.config.profile_memory = args.profile


def _check_argument_counts(args):
    """Exit with status 1 if more weights or sampling coefficients than
    training files were given.

    This runs before any expensive work (opening the model file, building
    the network), so misconfigured runs fail immediately.
    """
    num_training_files = len(args.training_set)
    if len(args.weights) > num_training_files:
        print("You specified more weights than training files.")
        sys.exit(1)
    if len(args.sampling) > num_training_files:
        print("You specified more sampling coefficients than training "
              "files.")
        sys.exit(1)


def _load_vocabulary(args, state):
    """Return the vocabulary for training.

    If ``state`` already has contents, the vocabulary is read from it.
    Otherwise it is built from the training files (when ``args.vocabulary``
    is None) or read from the file ``args.vocabulary``, and then passed to
    ``vocabulary.get_state(state)`` (presumably storing it in the model
    file -- confirm).
    """
    if state.keys():
        print("Reading vocabulary from existing network state.")
        sys.stdout.flush()
        return Vocabulary.from_state(state)

    if args.vocabulary is None:
        print("Constructing vocabulary from training set.")
        sys.stdout.flush()
        vocabulary = Vocabulary.from_corpus(args.training_set,
                                            args.num_classes)
        # Rewind the training files so training reads them from the start
        # (presumably from_corpus consumed them).
        for training_file in args.training_set:
            training_file.seek(0)
    else:
        print("Reading vocabulary from {}.".format(args.vocabulary))
        sys.stdout.flush()
        with open(args.vocabulary, 'rt', encoding='utf-8') as vocab_file:
            vocabulary = Vocabulary.from_file(vocab_file,
                                              args.vocabulary_format)
    vocabulary.get_state(state)
    return vocabulary


def _unk_settings(unk_penalty_arg):
    """Translate the unk-penalty argument into ``(ignore_unk, unk_penalty)``.

    - ``None``: ``(False, None)``, i.e. <unk> is scored by the network.
    - ``0``: ``(True, None)``, i.e. <unk> tokens are excluded from scores.
    - anything else: ``(False, unk_penalty_arg)``, i.e. <unk> tokens get
      that constant log probability.
    """
    if unk_penalty_arg is None:
        return False, None
    if unk_penalty_arg == 0:
        return True, None
    return False, unk_penalty_arg


def _training_file_weights(weight_args, num_training_files):
    """Return one weight per training file.

    The weights default to 1 and are overridden by ``weight_args`` in order.
    The array dtype is ``theano.config.floatX``.
    """
    weights = numpy.ones(num_training_files).astype(theano.config.floatX)
    for index, weight in enumerate(weight_args):
        weights[index] = weight
    return weights


def train(args):
    """Train a neural network language model as configured by ``args``.

    Network state is kept in the HDF5 file ``args.model_path``, opened in
    append mode with the in-memory ``'core'`` driver. After training, the
    best state is loaded into the network and the validation-set perplexity
    is printed.

    Raises:
        ValueError: if ``args.log_level`` is not a valid logging level.

    Exits with status 1 if more weights or sampling coefficients than
    training files were specified.
    """
    numpy.random.seed(args.random_seed)
    _configure_logging(args)
    _configure_theano(args)
    # Validate the cheap argument constraints before the model file is
    # touched and before the (potentially slow) network construction.
    _check_argument_counts(args)

    with h5py.File(args.model_path, 'a', driver='core') as state:
        vocabulary = _load_vocabulary(args, state)
        print("Number of words in vocabulary:", vocabulary.num_words())
        print("Number of word classes:", vocabulary.num_classes())

        print("Building neural network.")
        sys.stdout.flush()
        if args.architecture in ('lstm300', 'lstm1500'):
            architecture = Architecture.from_package(args.architecture)
        else:
            with open(args.architecture, 'rt', encoding='utf-8') as arch_file:
                architecture = Architecture.from_description(arch_file)
        network = Network(vocabulary, architecture, batch_processing=True,
                          profile=args.profile)
        sys.stdout.flush()

        ignore_unk, unk_penalty = _unk_settings(args.unk_penalty)
        weights = _training_file_weights(args.weights,
                                         len(args.training_set))

        print("Building text scorer.")
        scorer = TextScorer(network, ignore_unk, unk_penalty, args.profile)

        # NOTE(review): mmap's ``prot`` argument is POSIX-only, and this
        # mapping is never closed explicitly.
        validation_mmap = mmap.mmap(args.validation_file.fileno(), 0,
                                    prot=mmap.PROT_READ)
        validation_iter = LinearBatchIterator(validation_mmap, vocabulary,
                                              batch_size=32)

        optimization_options = {
            'method': args.optimization_method,
            'epsilon': args.numerical_stability_term,
            'gradient_decay_rate': args.gradient_decay_rate,
            'sqr_gradient_decay_rate': args.sqr_gradient_decay_rate,
            'learning_rate': args.learning_rate,
            'weights': weights,
            'momentum': args.momentum,
            'ignore_unk': ignore_unk,
            'unk_penalty': unk_penalty,
        }
        if args.gradient_normalization is not None:
            optimization_options['max_gradient_norm'] = \
                args.gradient_normalization

        # Kept inline (not in a helper) so the %(funcName)s log field
        # still reports "train".
        logging.debug("OPTIMIZATION OPTIONS")
        for option_name, option_value in optimization_options.items():
            if isinstance(option_value, list):
                value_str = ', '.join(str(x) for x in option_value)
                logging.debug("%s: [%s]", option_name, value_str)
            else:
                logging.debug("%s: %s", option_name, str(option_value))

        training_options = {
            'strategy': args.training_strategy,
            'batch_size': args.batch_size,
            'sequence_length': args.sequence_length,
            'validation_frequency': args.validation_frequency,
            'patience': args.patience,
            'stopping_criterion': args.stopping_criterion,
            'max_epochs': args.max_epochs,
            'min_epochs': args.min_epochs,
            'max_annealing_count': args.max_annealing_count,
        }
        logging.debug("TRAINING OPTIONS")
        for option_name, option_value in training_options.items():
            logging.debug("%s: %s", option_name, str(option_value))

        print("Building neural network trainer.")
        sys.stdout.flush()
        trainer = create_trainer(
            training_options, optimization_options,
            network, vocabulary, scorer,
            args.training_set, args.sampling, validation_iter,
            state, args.profile)
        trainer.set_logging(args.log_interval)

        print("Training neural network.")
        sys.stdout.flush()
        trainer.run()

        # NOTE(review): the vocabulary was already written into ``state``
        # above, so this emptiness test may never be true -- confirm what
        # the trainer leaves in ``state`` when no model was saved.
        if not state.keys():
            print("The model has not been trained.")
        else:
            network.set_state(state)
            perplexity = scorer.compute_perplexity(validation_iter)
            print("Best validation set perplexity:", perplexity)