Example #1
def score(args):
    """A function that performs the "theanolm score" command.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments
    """

    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        print("Invalid logging level requested:", args.log_level)
        sys.exit(1)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        logging.basicConfig(stream=sys.stdout,
                            format=log_format,
                            level=log_level)
    else:
        logging.basicConfig(filename=log_file,
                            format=log_format,
                            level=log_level)

    if args.debug:
        theano.config.compute_test_value = 'warn'
        logging.info("Enabled computing test values for tensor variables.")
        logging.warning("GpuArray backend will fail random number generation!")
    else:
        theano.config.compute_test_value = 'off'
    theano.config.profile = args.profile
    theano.config.profile_memory = args.profile

    default_device = get_default_device(args.default_device)
    network = Network.from_file(args.model_path,
                                exclude_unk=args.exclude_unk,
                                default_device=default_device)

    logging.info("Building text scorer.")
    scorer = TextScorer(network, args.shortlist, args.exclude_unk,
                        args.profile)

    logging.info("Scoring text.")
    if args.output == 'perplexity':
        _score_text(args.input_file, network.vocabulary, scorer,
                    args.output_file, args.log_base, args.subwords, False)
    elif args.output == 'word-scores':
        _score_text(args.input_file, network.vocabulary, scorer,
                    args.output_file, args.log_base, args.subwords, True)
    elif args.output == 'utterance-scores':
        _score_utterances(args.input_file, network.vocabulary, scorer,
                          args.output_file, args.log_base)
    else:
        print("Invalid output format requested:", args.output)
        sys.exit(1)
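The level lookup at the top of score() works because the logging module exposes its level names as integer constants. A minimal, self-contained sketch of that check (the level names below are just illustrative input):

import logging

# Resolve user-supplied level names to logging constants; anything that does
# not map to an integer constant (e.g. "verbose") is rejected as invalid.
for name in ("debug", "INFO", "verbose"):
    level = getattr(logging, name.upper(), None)
    print(name, "->", level if isinstance(level, int) else "invalid")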
Example #2
def sample(args):
    """A function that performs the "theanolm sample" command.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments
    """

    numpy.random.seed(args.random_seed)

    if args.debug:
        theano.config.compute_test_value = 'warn'
    else:
        theano.config.compute_test_value = 'off'

    with h5py.File(args.model_path, 'r') as state:
        logging.info("Reading vocabulary from network state.")
        vocabulary = Vocabulary.from_state(state)
        logging.info("Number of words in vocabulary: %d",
                     vocabulary.num_words())
        logging.info("Number of words in shortlist: %d",
                     vocabulary.num_shortlist_words())
        logging.info("Number of word classes: %d", vocabulary.num_classes())
        logging.info("Building neural network.")
        architecture = Architecture.from_state(state)
        default_device = get_default_device(args.default_device)
        network = Network(architecture,
                          vocabulary,
                          mode=Network.Mode(minibatch=False),
                          default_device=default_device)
        logging.info("Restoring neural network state.")
        network.set_state(state)

    logging.info("Building text sampler.")
    sampler = TextSampler(network)

    sequences = sampler.generate(args.sentence_length,
                                 args.num_sentences,
                                 seed_sequence=args.seed_sequence)
    for sequence in sequences:
        try:
            eos_pos = sequence.index('</s>')
            sequence = sequence[:eos_pos + 1]
        except ValueError:
            pass
        args.output_file.write(' '.join(sequence) + '\n')
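The loop at the end of sample() truncates each generated sequence at the first sentence-end token. A standalone illustration of that step with a made-up token list:

# Keep tokens up to and including the first '</s>'; if the sampler never
# produced one, the sequence is written out unchanged.
sequence = ['<s>', 'the', 'cat', 'sat', '</s>', 'padding', 'tokens']
try:
    eos_pos = sequence.index('</s>')
    sequence = sequence[:eos_pos + 1]
except ValueError:
    pass
print(' '.join(sequence))  # <s> the cat sat </s>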
Example #3
def train(args):
    """A function that performs the "theanolm train" command.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments
    """

    numpy.random.seed(args.random_seed)

    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        print("Invalid logging level requested:", args.log_level)
        sys.exit(1)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        logging.basicConfig(stream=sys.stdout,
                            format=log_format,
                            level=log_level)
    else:
        logging.basicConfig(filename=log_file,
                            format=log_format,
                            level=log_level)

    if args.debug:
        theano.config.compute_test_value = 'warn'
        logging.info("Enabled computing test values for tensor variables.")
        logging.warning("GpuArray backend will fail random number generation!")
    else:
        theano.config.compute_test_value = 'off'
    theano.config.profile = args.profile
    theano.config.profile_memory = args.profile

    # Open the model file in append mode with the in-memory 'core' driver;
    # the whole HDF5 file is held in memory and flushed back to disk on close.
    with h5py.File(args.model_path, 'a', driver='core') as state:
        vocabulary = _read_vocabulary(args, state)

        if args.num_noise_samples > vocabulary.num_classes():
            print("Number of noise samples ({}) is larger than the number of "
                  "classes. This doesn't make sense and would cause unigram "
                  "sampling to fail.".format(args.num_noise_samples))
            sys.exit(1)

        num_training_files = len(args.training_set)
        if len(args.weights) > num_training_files:
            print("You specified more weights than training files.")
            sys.exit(1)
        weights = numpy.ones(num_training_files).astype(theano.config.floatX)
        for index, weight in enumerate(args.weights):
            weights[index] = weight
        if len(args.sampling) > num_training_files:
            print("You specified more sampling coefficients than training "
                  "files.")
            sys.exit(1)

        training_options = {
            'batch_size': args.batch_size,
            'sequence_length': args.sequence_length,
            'validation_frequency': args.validation_frequency,
            'patience': args.patience,
            'stopping_criterion': args.stopping_criterion,
            'max_epochs': args.max_epochs,
            'min_epochs': args.min_epochs,
            'max_annealing_count': args.max_annealing_count
        }
        optimization_options = {
            'method': args.optimization_method,
            'epsilon': args.numerical_stability_term,
            'gradient_decay_rate': args.gradient_decay_rate,
            'sqr_gradient_decay_rate': args.sqr_gradient_decay_rate,
            'learning_rate': args.learning_rate,
            'weights': weights,
            'momentum': args.momentum,
            'max_gradient_norm': args.gradient_normalization,
            'num_noise_samples': args.num_noise_samples,
            'noise_sharing': args.noise_sharing,
        }

        log_options(training_options, optimization_options, args)

        logging.info("Creating trainer.")
        trainer = Trainer(training_options, vocabulary, args.training_set,
                          args.sampling)
        trainer.set_logging(args.log_interval)

        logging.info("Building neural network.")
        if args.architecture == 'lstm300' or args.architecture == 'lstm1500':
            architecture = Architecture.from_package(args.architecture)
        else:
            with open(args.architecture, 'rt', encoding='utf-8') as arch_file:
                architecture = Architecture.from_description(arch_file)

        default_device = get_default_device(args.default_device)
        network = Network(architecture,
                          vocabulary,
                          trainer.class_prior_probs,
                          default_device=default_device,
                          profile=args.profile)

        network.set_sampling(args.noise_distribution, args.noise_dampening,
                             args.noise_sharing)

        logging.info("Building optimizer.")
        exclude_id = vocabulary.word_to_id['<unk>'] if args.exclude_unk \
                     else None
        epsilon = args.numerical_stability_term
        if args.cost == 'cross-entropy':
            cost_function = CrossEntropyCost(network, exclude_id,
                                             args.l1_regularization,
                                             args.l2_regularization, epsilon)
        elif args.cost == 'nce':
            cost_function = NCECost(network, exclude_id,
                                    args.l1_regularization,
                                    args.l2_regularization, epsilon)
        else:
            assert args.cost == 'blackout'
            cost_function = BlackoutCost(network, exclude_id,
                                         args.l1_regularization,
                                         args.l2_regularization, epsilon)
        try:
            optimizer = create_optimizer(optimization_options,
                                         network,
                                         cost_function,
                                         profile=args.profile)
        except theano.gradient.DisconnectedInputError as e:
            print("Cannot train the neural network because some of the "
                  "parameters are disconnected from the output. Make sure all "
                  "the layers are correctly connected in the network "
                  "architecture. The error message was: `{}´".format(e))
            sys.exit(1)

        if args.print_graph:
            print("Cost function computation graph:")
            theano.printing.debugprint(optimizer.gradient_update_function)

        trainer.initialize(network, state, optimizer, args.load_and_train)

        if args.validation_file is not None:
            logging.info("Building text scorer for cross-validation.")
            scorer = TextScorer(network,
                                use_shortlist=True,
                                exclude_unk=args.exclude_unk,
                                profile=args.profile)
            logging.info("Validation text: %s", args.validation_file.name)
            validation_mmap = mmap.mmap(args.validation_file.fileno(),
                                        0,
                                        prot=mmap.PROT_READ)
            validation_iter = \
                LinearBatchIterator(validation_mmap,
                                    vocabulary,
                                    batch_size=args.batch_size,
                                    max_sequence_length=args.sequence_length,
                                    map_oos_to_unk=False)
            trainer.set_validation(validation_iter, scorer)
        else:
            logging.info("Cross-validation will not be performed.")
            validation_iter = None

        logging.info("Training neural network.")
        trainer.train()

        if 'layers' not in state.keys():
            print("The model has not been trained. No cross-validations were "
                  "performed or training did not improve the model.")
        elif validation_iter is not None:
            network.set_state(state)
            perplexity = scorer.compute_perplexity(validation_iter)
            print("Best validation set perplexity:", perplexity)
Example #4
def decode(args):
    """A function that performs the "theanolm decode" command.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments
    """

    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        print("Invalid logging level requested:", args.log_level,
              file=sys.stderr)
        sys.exit(1)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        logging.basicConfig(stream=sys.stdout, format=log_format, level=log_level)
    else:
        logging.basicConfig(filename=log_file, format=log_format, level=log_level)

    if args.debug:
        theano.config.compute_test_value = 'warn'
    else:
        theano.config.compute_test_value = 'off'
    theano.config.profile = args.profile
    theano.config.profile_memory = args.profile

    if (args.lattice_format == 'kaldi') or (args.output == 'kaldi'):
        if args.kaldi_vocabulary is None:
            print("Kaldi lattice vocabulary is not given.", file=sys.stderr)
            sys.exit(1)

    default_device = get_default_device(args.default_device)
    network = Network.from_file(args.model_path,
                                mode=Network.Mode(minibatch=False),
                                default_device=default_device)

    log_scale = 1.0 if args.log_base is None else numpy.log(args.log_base)
    if (args.log_base is not None) and (args.lattice_format == 'kaldi'):
        logging.info("Warning: Kaldi lattice reader doesn't support logarithm "
                     "base conversion.")

    if args.wi_penalty is None:
        wi_penalty = None
    else:
        wi_penalty = args.wi_penalty * log_scale
    decoding_options = {
        'nnlm_weight': args.nnlm_weight,
        'lm_scale': args.lm_scale,
        'wi_penalty': wi_penalty,
        'unk_penalty': args.unk_penalty,
        'use_shortlist': args.shortlist,
        'unk_from_lattice': args.unk_from_lattice,
        'linear_interpolation': args.linear_interpolation,
        'max_tokens_per_node': args.max_tokens_per_node,
        'beam': args.beam,
        'recombination_order': args.recombination_order,
        'prune_relative': args.prune_relative,
        'abs_min_max_tokens': args.abs_min_max_tokens,
        'abs_min_beam': args.abs_min_beam
    }
    logging.debug("DECODING OPTIONS")
    for option_name, option_value in decoding_options.items():
        logging.debug("%s: %s", option_name, str(option_value))

    logging.info("Building word lattice decoder.")
    decoder = LatticeDecoder(network, decoding_options)

    batch = LatticeBatch(args.lattices, args.lattice_list, args.lattice_format,
                         args.kaldi_vocabulary, args.num_jobs, args.job)
    for lattice_number, lattice in enumerate(batch):
        if lattice.utterance_id is None:
            lattice.utterance_id = str(lattice_number)
        logging.info("Utterance `%s´ -- %d of job %d",
                     lattice.utterance_id,
                     lattice_number + 1,
                     args.job)
        log_free_mem()

        final_tokens, recomb_tokens = decoder.decode(lattice)
        if (args.output == "slf") or (args.output == "kaldi"):
            rescored_lattice = RescoredLattice(lattice,
                                               final_tokens,
                                               recomb_tokens,
                                               network.vocabulary)
            rescored_lattice.lm_scale = args.lm_scale
            rescored_lattice.wi_penalty = args.wi_penalty
            if args.output == "slf":
                rescored_lattice.write_slf(args.output_file)
            else:
                assert args.output == "kaldi"
                rescored_lattice.write_kaldi(args.output_file,
                                             batch.kaldi_word_to_id)
        else:
            for token in final_tokens[:min(args.n_best, len(final_tokens))]:
                line = format_token(token,
                                    lattice.utterance_id,
                                    network.vocabulary,
                                    log_scale,
                                    args.output)
                args.output_file.write(line + "\n")
        gc.collect()
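The log_scale factor in decode() appears to convert a word-insertion penalty given in the requested logarithm base into the natural-log units used internally; when no base is requested the penalty passes through unchanged. A hedged sketch with hypothetical numbers:

import numpy

# Hypothetical penalty of -0.5 expressed in base-10 log units.
log_base = 10
wi_penalty_in_base = -0.5
log_scale = 1.0 if log_base is None else numpy.log(log_base)
wi_penalty = wi_penalty_in_base * log_scale
print(log_scale, wi_penalty)  # ~2.3026  ~-1.1513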