def train(args):
    """Run the "theanolm train" command.

    Seeds NumPy's random generator, configures logging and Theano, then opens
    the HDF5 model file given by ``args.model_path`` and uses it to read the
    vocabulary, build the trainer, network and optimizer, optionally set up
    cross-validation, and train. Progress messages are printed to standard
    output. Inconsistent arguments terminate the process with exit status 1.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments
    """
    def announce(message):
        # Print a progress message and flush so it shows up immediately.
        print(message)
        sys.stdout.flush()

    numpy.random.seed(args.random_seed)

    log_target = args.log_file
    level_name = args.log_level
    log_level = getattr(logging, level_name.upper(), None)
    if not isinstance(log_level, int):
        print("Invalid logging level requested:", level_name)
        sys.exit(1)

    # '-' selects standard output; any other value is used as a file name.
    if log_target == '-':
        destination = {'stream': sys.stdout}
    else:
        destination = {'filename': log_target}
    logging.basicConfig(format='%(asctime)s %(funcName)s: %(message)s',
                        level=log_level, **destination)

    theano.config.compute_test_value = 'warn' if args.debug else 'off'
    if args.debug:
        print("Enabled computing test values for tensor variables.")
        print("Warning: GpuArray backend will fail random number generation!")
    theano.config.profile = args.profile
    theano.config.profile_memory = args.profile

    with h5py.File(args.model_path, 'a', driver='core') as state:
        vocabulary = _read_vocabulary(args, state)

        noise_samples = args.num_noise_samples
        if noise_samples > vocabulary.num_classes():
            print("Number of noise samples ({}) is larger than the number of "
                  "classes. This doesn't make sense and would cause sampling "
                  "to fail.".format(noise_samples))
            sys.exit(1)

        file_count = len(args.training_set)
        if len(args.weights) > file_count:
            print("You specified more weights than training files.")
            sys.exit(1)
        # Every training file starts with weight 1; explicitly given weights
        # override the leading entries in order.
        weights = numpy.ones(file_count).astype(theano.config.floatX)
        for position, value in enumerate(args.weights):
            weights[position] = value

        training_options = dict(
            batch_size=args.batch_size,
            sequence_length=args.sequence_length,
            validation_frequency=args.validation_frequency,
            patience=args.patience,
            stopping_criterion=args.stopping_criterion,
            max_epochs=args.max_epochs,
            min_epochs=args.min_epochs,
            max_annealing_count=args.max_annealing_count,
        )
        logging.debug("TRAINING OPTIONS")
        for key, value in training_options.items():
            logging.debug("%s: %s", key, str(value))

        optimization_options = dict(
            method=args.optimization_method,
            epsilon=args.numerical_stability_term,
            gradient_decay_rate=args.gradient_decay_rate,
            sqr_gradient_decay_rate=args.sqr_gradient_decay_rate,
            learning_rate=args.learning_rate,
            weights=weights,
            momentum=args.momentum,
            max_gradient_norm=args.gradient_normalization,
            cost_function=args.cost,
            num_noise_samples=noise_samples,
            noise_sharing=args.noise_sharing,
            exclude_unk=args.exclude_unk,
        )
        logging.debug("OPTIMIZATION OPTIONS")
        for key, value in optimization_options.items():
            if isinstance(value, list):
                logging.debug("%s: [%s]", key, ', '.join(map(str, value)))
            else:
                logging.debug("%s: %s", key, str(value))

        if len(args.sampling) > file_count:
            print("You specified more sampling coefficients than training "
                  "files.")
            sys.exit(1)

        announce("Creating trainer.")
        trainer = Trainer(training_options, vocabulary, args.training_set,
                          args.sampling)
        trainer.set_logging(args.log_interval)

        announce("Building neural network.")
        if args.architecture in ('lstm300', 'lstm1500'):
            architecture = Architecture.from_package(args.architecture)
        else:
            with open(args.architecture, 'rt', encoding='utf-8') as arch_file:
                architecture = Architecture.from_description(arch_file)
        network = Network(architecture, vocabulary,
                          trainer.class_prior_probs, args.noise_dampening,
                          default_device=args.default_device,
                          profile=args.profile)

        announce("Compiling optimization function.")
        optimizer = create_optimizer(optimization_options, network,
                                     profile=args.profile)
        if args.print_graph:
            print("Cost function computation graph:")
            theano.printing.debugprint(optimizer.gradient_update_function)

        trainer.initialize(network, state, optimizer)

        # HACK: write the model back to disk right away. This only adds the
        # word unigram counts; temporary, to be removed at some point.
        trainer.get_state(state)
        state.flush()

        if args.validation_file is None:
            print("Cross-validation will not be performed.")
            validation_iter = None
        else:
            announce("Building text scorer for cross-validation.")
            scorer = TextScorer(network, use_shortlist=True,
                                exclude_unk=args.exclude_unk,
                                profile=args.profile)
            validation_file = args.validation_file
            print("Validation text:", validation_file.name)
            validation_mmap = mmap.mmap(validation_file.fileno(), 0,
                                        prot=mmap.PROT_READ)
            validation_iter = LinearBatchIterator(
                validation_mmap, vocabulary,
                batch_size=args.batch_size,
                max_sequence_length=args.sequence_length,
                map_oos_to_unk=False)
            trainer.set_validation(validation_iter, scorer)

        announce("Training neural network.")
        trainer.train()

        # 'layers' only exists in the model file once a model has been saved.
        if 'layers' not in state.keys():
            print("The model has not been trained. No cross-validations were "
                  "performed or training did not improve the model.")
        elif validation_iter is not None:
            network.set_state(state)
            perplexity = scorer.compute_perplexity(validation_iter)
            print("Best validation set perplexity:", perplexity)
def train(args):
    """A function that performs the "theanolm train" command.

    Seeds NumPy's random generator, configures logging and Theano, then opens
    the HDF5 model file given by ``args.model_path`` in append mode. The open
    file is used to read the vocabulary and is handed to the trainer. After
    training, the presence of ``'layers'`` in the file is used to decide
    whether a model was produced.

    The process exits with status 1 when:
    - the logging level is invalid,
    - there are more noise samples than vocabulary classes,
    - there are more weights or sampling coefficients than training files,
    - the requested cost function is unknown, or
    - some network parameters are disconnected from the output, so no
      optimizer can be built.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments
    """
    numpy.random.seed(args.random_seed)

    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        print("Invalid logging level requested:", args.log_level)
        sys.exit(1)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    # '-' selects standard output; anything else is used as a file name.
    if log_file == '-':
        logging.basicConfig(stream=sys.stdout, format=log_format,
                            level=log_level)
    else:
        logging.basicConfig(filename=log_file, format=log_format,
                            level=log_level)

    if args.debug:
        theano.config.compute_test_value = 'warn'
        logging.info("Enabled computing test values for tensor variables.")
        logging.warning("GpuArray backend will fail random number "
                        "generation!")
    else:
        theano.config.compute_test_value = 'off'
    theano.config.profile = args.profile
    theano.config.profile_memory = args.profile

    with h5py.File(args.model_path, 'a', driver='core') as state:
        vocabulary = _read_vocabulary(args, state)

        if args.num_noise_samples > vocabulary.num_classes():
            print("Number of noise samples ({}) is larger than the number of "
                  "classes. This doesn't make sense and would cause unigram "
                  "sampling to fail.".format(args.num_noise_samples))
            sys.exit(1)

        num_training_files = len(args.training_set)
        if len(args.weights) > num_training_files:
            print("You specified more weights than training files.")
            sys.exit(1)
        # Files without an explicit weight keep the default weight 1.
        weights = numpy.ones(num_training_files).astype(theano.config.floatX)
        for index, weight in enumerate(args.weights):
            weights[index] = weight

        if len(args.sampling) > num_training_files:
            print("You specified more sampling coefficients than training "
                  "files.")
            sys.exit(1)

        training_options = {
            'batch_size': args.batch_size,
            'sequence_length': args.sequence_length,
            'validation_frequency': args.validation_frequency,
            'patience': args.patience,
            'stopping_criterion': args.stopping_criterion,
            'max_epochs': args.max_epochs,
            'min_epochs': args.min_epochs,
            'max_annealing_count': args.max_annealing_count,
        }
        optimization_options = {
            'method': args.optimization_method,
            'epsilon': args.numerical_stability_term,
            'gradient_decay_rate': args.gradient_decay_rate,
            'sqr_gradient_decay_rate': args.sqr_gradient_decay_rate,
            'learning_rate': args.learning_rate,
            'weights': weights,
            'momentum': args.momentum,
            'max_gradient_norm': args.gradient_normalization,
            'num_noise_samples': args.num_noise_samples,
            'noise_sharing': args.noise_sharing,
        }
        log_options(training_options, optimization_options, args)

        logging.info("Creating trainer.")
        trainer = Trainer(training_options, vocabulary, args.training_set,
                          args.sampling)
        trainer.set_logging(args.log_interval)

        logging.info("Building neural network.")
        if args.architecture == 'lstm300' or args.architecture == 'lstm1500':
            architecture = Architecture.from_package(args.architecture)
        else:
            with open(args.architecture, 'rt', encoding='utf-8') as arch_file:
                architecture = Architecture.from_description(arch_file)
        default_device = get_default_device(args.default_device)
        network = Network(architecture, vocabulary, trainer.class_prior_probs,
                          default_device=default_device,
                          profile=args.profile)
        network.set_sampling(args.noise_distribution, args.noise_dampening,
                             args.noise_sharing)

        logging.info("Building optimizer.")
        exclude_id = vocabulary.word_to_id['<unk>'] if args.exclude_unk \
                     else None
        epsilon = args.numerical_stability_term
        # Validate explicitly instead of with `assert`, which is stripped
        # under `python -O` and would silently fall through to BlackoutCost.
        if args.cost == 'cross-entropy':
            cost_class = CrossEntropyCost
        elif args.cost == 'nce':
            cost_class = NCECost
        elif args.cost == 'blackout':
            cost_class = BlackoutCost
        else:
            print("Unknown cost function requested:", args.cost)
            sys.exit(1)
        cost_function = cost_class(network, exclude_id,
                                   args.l1_regularization,
                                   args.l2_regularization, epsilon)

        try:
            optimizer = create_optimizer(optimization_options, network,
                                         cost_function, profile=args.profile)
        except theano.gradient.DisconnectedInputError as e:
            print("Cannot train the neural network because some of the "
                  "parameters are disconnected from the output. Make sure all "
                  "the layers are correctly connected in the network "
                  "architecture. The error message was: `{}´".format(e))
            # Without an optimizer there is nothing to train. Continuing
            # would crash below with a NameError on `optimizer`.
            sys.exit(1)

        if args.print_graph:
            print("Cost function computation graph:")
            theano.printing.debugprint(optimizer.gradient_update_function)

        trainer.initialize(network, state, optimizer, args.load_and_train)

        if args.validation_file is not None:
            logging.info("Building text scorer for cross-validation.")
            scorer = TextScorer(network, use_shortlist=True,
                                exclude_unk=args.exclude_unk,
                                profile=args.profile)
            logging.info("Validation text: %s", args.validation_file.name)
            validation_mmap = mmap.mmap(args.validation_file.fileno(), 0,
                                        prot=mmap.PROT_READ)
            validation_iter = \
                LinearBatchIterator(validation_mmap, vocabulary,
                                    batch_size=args.batch_size,
                                    max_sequence_length=args.sequence_length,
                                    map_oos_to_unk=False)
            trainer.set_validation(validation_iter, scorer)
        else:
            logging.info("Cross-validation will not be performed.")
            validation_iter = None

        logging.info("Training neural network.")
        trainer.train()

        # 'layers' is only present in the model file once a model was saved.
        if 'layers' not in state.keys():
            print("The model has not been trained. No cross-validations were "
                  "performed or training did not improve the model.")
        elif validation_iter is not None:
            network.set_state(state)
            perplexity = scorer.compute_perplexity(validation_iter)
            print("Best validation set perplexity:", perplexity)