Esempio n. 1
0
def initialize_all(config, save_path, bokeh_name, params, bokeh_server, bokeh,
                   test_tag, use_load_ext, load_log, fast_start):
    """Build the model, the training algorithm and the main-loop extensions.

    Parameters
    ----------
    config : dict-like
        Experiment configuration. The keys read here are ``'data'``,
        ``'training'``, ``'monitoring'``, ``'net'`` and, optionally,
        ``'regularization'``.
    save_path : str
        Checkpoint path. Its root and extension are also used to derive
        the ``_best`` and ``_best_ll`` checkpoint names.
    bokeh_name : str or None
        Name of the Bokeh plot; the basename of `save_path` is used when
        this is falsy.
    params : str or None
        Path to a parameter file to load into the model; skipped when
        falsy.
    bokeh_server : str
        Passed as ``server_url`` to the ``Plot`` extension.
    bokeh : bool
        Whether to add the ``Plot`` extension.
    test_tag
        Forwarded to ``create_model``.
    use_load_ext : bool
        When true (and `params` is given) a ``Load`` extension is added.
    load_log : bool
        When true (and `params` is given) a ``LoadLog`` extension is added.
    fast_start : bool
        When true, validation, search and checkpointing before the first
        epoch are skipped.

    Returns
    -------
    tuple
        ``(model, algorithm, data, extensions)``.

    """
    root_path, extension = os.path.splitext(save_path)

    data = Data(**config['data'])
    train_conf = config['training']
    recognizer = create_model(config, data, test_tag)

    # Separate attention_params to be handled differently
    # when regularization is applied
    attention = recognizer.generator.transition.attention
    attention_params = Selector(attention).get_parameters().values()

    logger.info(
        "Initialization schemes for all bricks.\n"
        "Works well only in my branch with __repr__ added to all them,\n"
        "there is an issue #463 in Blocks to do that properly.")

    def show_init_scheme(cur):
        # Recursively collect every ``*_init`` attribute of a brick and
        # of its children into a nested dictionary.
        result = dict()
        for attr in dir(cur):
            if attr.endswith('_init'):
                result[attr] = getattr(cur, attr)
        for child in cur.children:
            result[child.name] = show_init_scheme(child)
        return result

    logger.info(pprint.pformat(show_init_scheme(recognizer)))

    prediction, prediction_mask = add_exploration(recognizer, data, train_conf)

    #
    # Observables:
    #
    primary_observables = []  # monitored each batch
    secondary_observables = []  # monitored every 10 batches
    validation_observables = []  # monitored on the validation set

    cg = recognizer.get_cost_graph(batch=True,
                                   prediction=prediction,
                                   prediction_mask=prediction_mask)
    labels, = VariableFilter(applications=[recognizer.cost], name='labels')(cg)
    labels_mask, = VariableFilter(applications=[recognizer.cost],
                                  name='labels_mask')(cg)

    gain_matrix = VariableFilter(
        theano_name=RewardRegressionEmitter.GAIN_MATRIX)(cg)
    if len(gain_matrix):
        gain_matrix, = gain_matrix
        primary_observables.append(rename(gain_matrix.min(), 'min_gain'))
        primary_observables.append(rename(gain_matrix.max(), 'max_gain'))

    batch_cost = cg.outputs[0].sum()
    batch_size = rename(recognizer.labels.shape[1], "batch_size")
    # Assumes constant batch size. `aggregation.mean` is not used because
    # of Blocks #514.
    cost = batch_cost / batch_size
    cost.name = "sequence_total_cost"
    logger.info("Cost graph is built")

    # Fetch variables useful for debugging.
    # It is important not to use any aggregation schemes here,
    # as it's currently impossible to spread the effect of
    # regularization on their variables, see Blocks #514.
    cost_cg = ComputationGraph(cost)
    r = recognizer
    energies, = VariableFilter(applications=[r.generator.readout.readout],
                               name="output_0")(cost_cg)
    bottom_output = VariableFilter(
        # We need name_regex instead of name because LookupTable calls
        # its output output_0
        applications=[r.bottom.apply],
        name_regex="output")(cost_cg)[-1]
    attended, = VariableFilter(applications=[r.generator.transition.apply],
                               name="attended")(cost_cg)
    attended_mask, = VariableFilter(applications=[
        r.generator.transition.apply
    ],
                                    name="attended_mask")(cost_cg)
    weights, = VariableFilter(applications=[r.generator.evaluate],
                              name="weights")(cost_cg)

    from blocks.roles import AUXILIARY
    l2_cost, = VariableFilter(roles=[AUXILIARY],
                              theano_name='l2_cost_aux')(cost_cg)
    cost_forward, = VariableFilter(roles=[AUXILIARY],
                                   theano_name='costs_forward_aux')(cost_cg)

    max_recording_length = rename(bottom_output.shape[0],
                                  "max_recording_length")
    # To exclude subsampling related bugs
    max_attended_mask_length = rename(attended_mask.shape[0],
                                      "max_attended_mask_length")
    max_attended_length = rename(attended.shape[0], "max_attended_length")
    max_num_phonemes = rename(labels.shape[0], "max_num_phonemes")
    min_energy = rename(energies.min(), "min_energy")
    max_energy = rename(energies.max(), "max_energy")
    mean_attended = rename(abs(attended).mean(), "mean_attended")
    mean_bottom_output = rename(
        abs(bottom_output).mean(), "mean_bottom_output")
    weights_penalty = rename(monotonicity_penalty(weights, labels_mask),
                             "weights_penalty")
    weights_entropy = rename(entropy(weights, labels_mask), "weights_entropy")
    mask_density = rename(labels_mask.mean(), "mask_density")
    # NOTE: the order of the outputs matters, the code below indexes
    # them positionally (outputs[0] is the cost, outputs[1] the penalty).
    cg = ComputationGraph([
        cost, weights_penalty, weights_entropy, min_energy, max_energy,
        mean_attended, mean_bottom_output, batch_size, max_num_phonemes,
        mask_density
    ])
    # Regularization. It is applied explicitly to all variables
    # of interest, it could not be applied to the cost only as it
    # would not have effect on auxiliary variables, see Blocks #514.
    reg_config = config.get('regularization', dict())
    regularized_cg = cg
    if reg_config.get('dropout'):
        logger.info('apply dropout')
        regularized_cg = apply_dropout(cg, [bottom_output], 0.5)
    if reg_config.get('noise'):
        logger.info('apply noise')
        noise_subjects = [
            p for p in cg.parameters if p not in attention_params
        ]
        # Noise is stacked on top of whatever regularization was applied
        # so far; starting again from `cg` would silently drop dropout
        # when both options are enabled.
        regularized_cg = apply_noise(regularized_cg, noise_subjects,
                                     reg_config['noise'])

    train_cost = regularized_cg.outputs[0]
    if reg_config.get("penalty_coof", .0) > 0:
        # big warning!!!
        # here we assume that:
        # regularized_weights_penalty = regularized_cg.outputs[1]
        train_cost = (train_cost + reg_config.get("penalty_coof", .0) *
                      regularized_cg.outputs[1] / batch_size)
    if reg_config.get("decay", .0) > 0:
        train_cost = (
            train_cost + reg_config.get("decay", .0) *
            l2_norm(VariableFilter(roles=[WEIGHT])(cg.parameters))**2)

    train_cost = rename(train_cost, 'train_cost')

    gradients = None
    if reg_config.get('adaptive_noise'):
        logger.info('apply adaptive noise')
        if ((reg_config.get("penalty_coof", .0) > 0)
                or (reg_config.get("decay", .0) > 0)):
            logger.error('using  adaptive noise with alignment weight panalty '
                         'or weight decay is probably stupid')
        train_cost, regularized_cg, gradients, noise_brick = apply_adaptive_noise(
            cg,
            cg.outputs[0],
            variables=cg.parameters,
            num_examples=data.get_dataset('train').num_examples,
            parameters=Model(
                regularized_cg.outputs[0]).get_parameter_dict().values(),
            **reg_config.get('adaptive_noise'))
        train_cost.name = 'train_cost'
        adapt_noise_cg = ComputationGraph(train_cost)
        model_prior_mean = rename(
            VariableFilter(applications=[noise_brick.apply],
                           name='model_prior_mean')(adapt_noise_cg)[0],
            'model_prior_mean')
        model_cost = rename(
            VariableFilter(applications=[noise_brick.apply],
                           name='model_cost')(adapt_noise_cg)[0], 'model_cost')
        model_prior_variance = rename(
            VariableFilter(applications=[noise_brick.apply],
                           name='model_prior_variance')(adapt_noise_cg)[0],
            'model_prior_variance')
        regularized_cg = ComputationGraph(
            [train_cost, model_cost] + regularized_cg.outputs +
            [model_prior_mean, model_prior_variance])
        primary_observables += [
            regularized_cg.outputs[1],  # model cost
            regularized_cg.outputs[2],  # task cost
            regularized_cg.outputs[-2],  # model prior mean
            regularized_cg.outputs[-1]
        ]  # model prior variance

    model = Model(train_cost)
    if params:
        logger.info("Load parameters from " + params)
        # please note: we cannot use recognizer.load_params
        # as it builds a new computation graph that does not have
        # shared variables added by adaptive weight noise
        # The parameter file is a binary archive, so it is opened in
        # binary mode (text mode corrupts or rejects it on Python 3).
        with open(params, 'rb') as src:
            param_values = load_parameters(src)
        model.set_parameter_values(param_values)

    parameters = model.get_parameter_dict()
    logger.info("Parameters:\n" +
                pprint.pformat([(key, parameters[key].get_value().shape)
                                for key in sorted(parameters.keys())],
                               width=120))

    # Define the training algorithm.
    clipping = StepClipping(train_conf['gradient_threshold'])
    clipping.threshold.name = "gradient_norm_threshold"
    rule_names = train_conf.get('rules', ['momentum'])
    core_rules = []
    if 'momentum' in rule_names:
        logger.info("Using scaling and momentum for training")
        core_rules.append(Momentum(train_conf['scale'],
                                   train_conf['momentum']))
    if 'adadelta' in rule_names:
        logger.info("Using AdaDelta for training")
        core_rules.append(
            AdaDelta(train_conf['decay_rate'], train_conf['epsilon']))
    max_norm_rules = []
    if reg_config.get('max_norm', False) > 0:
        logger.info("Apply MaxNorm")
        maxnorm_subjects = VariableFilter(roles=[WEIGHT])(cg.parameters)
        if reg_config.get('max_norm_exclude_lookup', False):
            maxnorm_subjects = [
                v for v in maxnorm_subjects
                if not isinstance(get_brick(v), LookupTable)
            ]
        logger.info("Parameters covered by MaxNorm:\n" + pprint.pformat(
            [name for name, p in parameters.items() if p in maxnorm_subjects]))
        logger.info("Parameters NOT covered by MaxNorm:\n" + pprint.pformat([
            name for name, p in parameters.items() if p not in maxnorm_subjects
        ]))
        max_norm_rules = [
            Restrict(VariableClipping(reg_config['max_norm'], axis=0),
                     maxnorm_subjects)
        ]
    burn_in = []
    if train_conf.get('burn_in_steps', 0):
        burn_in.append(BurnIn(num_steps=train_conf['burn_in_steps']))
    algorithm = GradientDescent(
        cost=train_cost,
        parameters=parameters.values(),
        gradients=gradients,
        step_rule=CompositeRule(
            [clipping] + core_rules + max_norm_rules +
            # Parameters are not changed at all
            # when nans are encountered.
            [RemoveNotFinite(0.0)] + burn_in),
        on_unused_sources='warn')

    logger.debug("Scan Ops in the gradients")
    gradient_cg = ComputationGraph(algorithm.gradients.values())
    for op in ComputationGraph(gradient_cg).scans:
        logger.debug(op)

    # More variables for debugging: some of them can be added only
    # after the `algorithm` object is created.
    secondary_observables += list(regularized_cg.outputs)
    if 'train_cost' not in [v.name for v in secondary_observables]:
        secondary_observables += [train_cost]
    secondary_observables += [
        algorithm.total_step_norm, algorithm.total_gradient_norm,
        clipping.threshold
    ]
    for name, param in parameters.items():
        # `numpy.prod` instead of the `numpy.product` alias, which was
        # removed in NumPy 2.0.
        num_elements = numpy.prod(param.get_value().shape)
        norm = param.norm(2) / num_elements**0.5
        grad_norm = algorithm.gradients[param].norm(2) / num_elements**0.5
        step_norm = algorithm.steps[param].norm(2) / num_elements**0.5
        stats = tensor.stack(norm, grad_norm, step_norm, step_norm / grad_norm)
        stats.name = name + '_stats'
        secondary_observables.append(stats)

    primary_observables += [
        train_cost, algorithm.total_gradient_norm, algorithm.total_step_norm,
        clipping.threshold, max_recording_length, max_attended_length,
        max_attended_mask_length
    ]

    validation_observables += [
        rename(aggregation.mean(batch_cost, batch_size), cost.name),
        rename(aggregation.sum_(batch_size), 'num_utterances'),
        weights_entropy, weights_penalty
    ]

    def attach_aggregation_schemes(variables):
        # Aggregation specification has to be factored out as a separate
        # function as it has to be applied at the very last stage
        # separately to training and validation observables.
        result = []
        for var in variables:
            if var.name == 'weights_penalty':
                result.append(
                    rename(aggregation.mean(var, batch_size),
                           'weights_penalty_per_recording'))
            elif var.name == 'weights_entropy':
                result.append(
                    rename(aggregation.mean(var, labels_mask.sum()),
                           'weights_entropy_per_label'))
            else:
                result.append(var)
        return result

    mon_conf = config['monitoring']

    # Build main loop.
    logger.info("Initialize extensions")
    extensions = []
    if use_load_ext and params:
        extensions.append(
            Load(params, load_iteration_state=True, load_log=True))
    if load_log and params:
        extensions.append(LoadLog(params))
    extensions += [
        Timing(after_batch=True),
        CGStatistics(),
        #CodeVersion(['lvsr']),
    ]
    extensions.append(
        TrainingDataMonitoring(primary_observables + [l2_cost, cost_forward],
                               after_batch=True))
    average_monitoring = TrainingDataMonitoring(
        attach_aggregation_schemes(secondary_observables),
        prefix="average",
        every_n_batches=10)
    extensions.append(average_monitoring)
    validation = DataStreamMonitoring(
        attach_aggregation_schemes(validation_observables +
                                   [l2_cost, cost_forward]),
        data.get_stream("valid", shuffle=False),
        prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=mon_conf['validate_every_epochs'],
            every_n_batches=mon_conf['validate_every_batches'],
            after_training=False)
    extensions.append(validation)
    per = PhonemeErrorRate(recognizer, data, **config['monitoring']['search'])
    per_monitoring = DataStreamMonitoring(
        [per],
        data.get_stream("valid", batches=False, shuffle=False),
        prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=mon_conf['search_every_epochs'],
            every_n_batches=mon_conf['search_every_batches'],
            after_training=False)
    extensions.append(per_monitoring)
    track_the_best_per = TrackTheBest(
        per_monitoring.record_name(per)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    track_the_best_cost = TrackTheBest(
        validation.record_name(cost)).set_conditions(before_first_epoch=True,
                                                     after_epoch=True)
    extensions += [track_the_best_cost, track_the_best_per]
    extensions.append(
        AdaptiveClipping(algorithm.total_gradient_norm.name,
                         clipping,
                         train_conf['gradient_threshold'],
                         decay_rate=0.998,
                         burnin_period=500))
    extensions += [
        SwitchOffLengthFilter(
            data.length_filter,
            after_n_batches=train_conf.get('stop_filtering')),
        FinishAfter(after_n_batches=train_conf.get('num_batches'),
                    after_n_epochs=train_conf.get('num_epochs')).add_condition(
                        ["after_batch"], _gradient_norm_is_none),
    ]
    channels = [
        # Plot 1: training and validation costs
        [
            average_monitoring.record_name(train_cost),
            validation.record_name(cost)
        ],
        # Plot 2: gradient norm,
        [
            average_monitoring.record_name(algorithm.total_gradient_norm),
            average_monitoring.record_name(clipping.threshold)
        ],
        # Plot 3: phoneme error rate
        [per_monitoring.record_name(per)],
        # Plot 4: training and validation mean weight entropy
        [
            average_monitoring._record_name('weights_entropy_per_label'),
            validation._record_name('weights_entropy_per_label')
        ],
        # Plot 5: training and validation monotonicity penalty
        [
            average_monitoring._record_name('weights_penalty_per_recording'),
            validation._record_name('weights_penalty_per_recording')
        ]
    ]
    if bokeh:
        extensions += [
            Plot(bokeh_name if bokeh_name else os.path.basename(save_path),
                 channels,
                 every_n_batches=10,
                 server_url=bokeh_server),
        ]
    extensions += [
        Checkpoint(save_path,
                   before_first_epoch=not fast_start,
                   after_epoch=True,
                   every_n_batches=train_conf.get('save_every_n_batches'),
                   save_separately=["model", "log"],
                   use_cpickle=True).add_condition(
                       ['after_epoch'],
                       OnLogRecord(track_the_best_per.notification_name),
                       (root_path + "_best" + extension, )).add_condition(
                           ['after_epoch'],
                           OnLogRecord(track_the_best_cost.notification_name),
                           (root_path + "_best_ll" + extension, )),
        ProgressBar()
    ]
    extensions.append(EmbedIPython(use_main_loop_run_caller_env=True))
    if config['net']['criterion']['name'].startswith('mse'):
        extensions.append(
            LogInputsGains(labels, cg, recognizer.generator.readout.emitter,
                           data))

    if train_conf.get('patience'):
        patience_conf = train_conf['patience']
        if not patience_conf.get('notification_names'):
            # setdefault will not work for empty list
            patience_conf['notification_names'] = [
                track_the_best_per.notification_name,
                track_the_best_cost.notification_name
            ]
        extensions.append(Patience(**patience_conf))

    extensions.append(
        Printing(every_n_batches=1, attribute_filter=PrintingFilterList()))

    return model, algorithm, data, extensions
Esempio n. 2
0
def search(config, params, load_path, part, decode_only, report, decoded_save,
           nll_only, seed):
    """Decode one part of the dataset with beam search and report errors.

    For every selected utterance the groundtruth cost and alignment
    statistics are printed; unless `nll_only` is set, beam search is run
    and the per-utterance and running average error rates are printed too.

    Parameters
    ----------
    config : dict-like
        Experiment configuration. The keys read here are ``'data'``,
        ``'monitoring'`` (its ``'search'`` entry) and, optionally,
        ``'vocabulary'`` (path to a two-column word mapping that enables
        the WER report).
    params
        Not used by this function.
    load_path
        Forwarded to ``create_model``.
    part : str
        Name of the dataset part to decode. For ``'train'`` the stream is
        shuffled and limited to 500 examples.
    decode_only : str or None
        Python expression evaluating to a container of utterance numbers;
        only those utterances are processed.
    report : str or None
        Directory for ``report.txt`` and the alignment plots. When falsy
        the report goes to standard output and no plots are saved.
    decoded_save : str or None
        File that receives one ``"<uttid> <recognized>"`` line per
        utterance.
    nll_only : bool
        When true only the groundtruth statistics are computed.
    seed
        Forwarded to ``data.get_stream``.

    """
    import matplotlib
    matplotlib.use("Agg")
    from matplotlib import pyplot
    from lvsr.notebook import show_alignment

    data = Data(**config['data'])
    search_conf = config['monitoring']['search']

    logger.info("Recognizer initialization started")
    recognizer = create_model(config, data, load_path)
    recognizer.init_beam_search(search_conf['beam_size'])
    logger.info("Recognizer is initialized")

    has_uttids = 'uttids' in data.info_dataset.provides_sources
    add_sources = ('uttids', ) if has_uttids else ()
    dataset = data.get_dataset(part, add_sources)
    stream = data.get_stream(part,
                             batches=False,
                             shuffle=part == 'train',
                             add_sources=add_sources,
                             num_examples=500 if part == 'train' else None,
                             seed=seed)
    it = stream.get_epoch_iterator(as_dict=True)
    if decode_only is not None:
        # NOTE(security): `eval` executes arbitrary code. This is only
        # acceptable while `decode_only` comes from a trusted operator.
        decode_only = eval(decode_only)

    weights = tensor.matrix('weights')
    weight_statistics = theano.function([weights], [
        weights_std(weights.dimshuffle(0, 'x', 1)),
        monotonicity_penalty(weights.dimshuffle(0, 'x', 1))
    ])

    num_examples = .0
    total_nll = .0
    total_errors = .0
    total_length = .0
    total_wer_errors = .0
    total_word_length = 0.

    use_vocabulary = bool(config.get('vocabulary'))
    if use_vocabulary:
        with open(os.path.expandvars(config['vocabulary'])) as f:
            vocabulary = dict(line.split() for line in f)

        def to_words(chars):
            # Map every whitespace-separated word through the vocabulary,
            # falling back to the '<UNK>' entry for unknown words.
            return [
                vocabulary[word] if word in vocabulary else vocabulary['<UNK>']
                for word in chars.split()
            ]

    def average(total, count):
        # Running average that yields NaN instead of raising while the
        # accumulated reference length is still zero.
        return total / count if count else float('nan')

    print_to = sys.stdout
    decoded_file = None
    try:
        if report:
            alignments_path = os.path.join(report, "alignments")
            # Create the alignments directory even when the report
            # directory itself already exists.
            if not os.path.isdir(alignments_path):
                os.makedirs(alignments_path)
            print_to = open(os.path.join(report, "report.txt"), 'w')
        if decoded_save:
            decoded_file = open(decoded_save, 'w')

        for number, example in enumerate(it):
            if decode_only and number not in decode_only:
                continue
            uttids = example.pop('uttids', None)
            raw_groundtruth = example.pop('labels')
            required_inputs = dict_subset(example, recognizer.inputs.keys())

            print("Utterance {} ({})".format(number, uttids), file=print_to)

            groundtruth = dataset.decode(raw_groundtruth)
            groundtruth_text = dataset.pretty_print(raw_groundtruth, example)
            costs_groundtruth, weights_groundtruth = recognizer.analyze(
                inputs=required_inputs,
                groundtruth=raw_groundtruth,
                prediction=raw_groundtruth)[:2]
            weight_std_groundtruth, mono_penalty_groundtruth = (
                weight_statistics(weights_groundtruth))
            total_nll += costs_groundtruth.sum()
            num_examples += 1
            print("Groundtruth:", groundtruth_text, file=print_to)
            print("Groundtruth cost:", costs_groundtruth.sum(), file=print_to)
            print("Groundtruth weight std:", weight_std_groundtruth,
                  file=print_to)
            print("Groundtruth monotonicity penalty:",
                  mono_penalty_groundtruth,
                  file=print_to)
            print("Average groundtruth cost: {}".format(
                total_nll / num_examples), file=print_to)
            if nll_only:
                print_to.flush()
                continue

            before = time.time()
            try:
                search_kwargs = dict(
                    char_discount=search_conf.get('char_discount'),
                    round_to_inf=search_conf.get('round_to_inf'),
                    stop_on=search_conf.get('stop_on'),
                    validate_solution_function=getattr(
                        data.info_dataset, 'validate_solution', None))
                # Only pass the options that are actually set.
                search_kwargs = {k: v for k, v in search_kwargs.items() if v}
                outputs, search_costs = recognizer.beam_search(
                    required_inputs, **search_kwargs)
            except CandidateNotFoundError:
                logger.error('Candidate not found!')
                outputs = [[]]
                # `numpy.nan`: the `numpy.NaN` alias is gone in NumPy 2.0.
                search_costs = [[numpy.nan]]

            took = time.time() - before
            recognized = dataset.decode(outputs[0])
            recognized_text = dataset.pretty_print(outputs[0], example)
            if recognized:
                # Theano scan doesn't work with 0 length sequences
                costs_recognized, weights_recognized = recognizer.analyze(
                    inputs=required_inputs,
                    groundtruth=raw_groundtruth,
                    prediction=outputs[0])[:2]
                weight_std_recognized, mono_penalty_recognized = (
                    weight_statistics(weights_recognized))
                error = min(1, wer(groundtruth, recognized))
            else:
                error = 1
            total_errors += len(groundtruth) * error
            total_length += len(groundtruth)

            if use_vocabulary:
                groundtruth_words = to_words(groundtruth_text)
                wer_error = min(
                    1, wer(groundtruth_words, to_words(recognized_text)))
                # Weight by the number of reference *words*; weighting by
                # the character count skews the running average WER.
                total_wer_errors += len(groundtruth_words) * wer_error
                total_word_length += len(groundtruth_words)

            if report and recognized:
                show_alignment(weights_groundtruth, groundtruth,
                               bos_symbol=True)
                pyplot.savefig(
                    os.path.join(alignments_path,
                                 "{}.groundtruth.png".format(number)))
                show_alignment(weights_recognized, recognized,
                               bos_symbol=True)
                pyplot.savefig(
                    os.path.join(alignments_path,
                                 "{}.recognized.png".format(number)))

            if decoded_file is not None:
                print("{} {}".format(uttids, ' '.join(recognized)),
                      file=decoded_file)

            print("Decoding took:", took, file=print_to)
            print("Beam search cost:", search_costs[0], file=print_to)
            print("Recognized:", recognized_text, file=print_to)
            if recognized:
                print("Recognized cost:", costs_recognized.sum(),
                      file=print_to)
                print("Recognized weight std:",
                      weight_std_recognized,
                      file=print_to)
                print("Recognized monotonicity penalty:",
                      mono_penalty_recognized,
                      file=print_to)
            print("CER:", error, file=print_to)
            print("Average CER:", average(total_errors, total_length),
                  file=print_to)
            if use_vocabulary:
                print("WER:", wer_error, file=print_to)
                print("Average WER:",
                      average(total_wer_errors, total_word_length),
                      file=print_to)
            print_to.flush()
    finally:
        # Close whatever this function opened, even if decoding failed;
        # standard output is left alone.
        if print_to is not sys.stdout:
            print_to.close()
        if decoded_file is not None:
            decoded_file.close()
Esempio n. 3
0
def search(config, params, load_path, beam_size, part, decode_only, report,
           decoded_save, nll_only, char_discount):
    """Decode one part of the dataset with beam search and report errors.

    For every selected utterance the groundtruth cost and alignment
    statistics are printed; unless `nll_only` is set, beam search is run
    and the per-utterance and running average CER and WER are printed too.

    Parameters
    ----------
    config : dict-like
        Experiment configuration. The keys read here are ``'data'``,
        ``'net'`` and ``'vocabulary'`` (path to a two-column word mapping
        used for the WER report).
    params
        Not used by this function.
    load_path
        Passed to ``recognizer.load_params``.
    beam_size
        Passed to ``recognizer.init_beam_search``.
    part : str
        Name of the dataset part to decode.
    decode_only : str or None
        Python expression evaluating to a container of utterance numbers;
        only those utterances are processed.
    report : str or None
        Directory for ``report.txt`` and the alignment plots. When falsy
        the report goes to standard output and no plots are saved.
    decoded_save : str or None
        File that receives one ``"<uttid> <recognized>"`` line per
        utterance.
    nll_only : bool
        When true only the groundtruth statistics are computed.
    char_discount
        Forwarded to ``recognizer.beam_search``.

    """
    from matplotlib import pyplot
    from lvsr.notebook import show_alignment

    data = Data(**config['data'])

    recognizer = SpeechRecognizer(
        data.recordings_source, data.labels_source,
        data.eos_label, data.num_features, data.num_labels,
        character_map=data.character_map,
        name='recognizer', **config["net"])
    recognizer.load_params(load_path)
    recognizer.init_beam_search(beam_size)

    dataset = data.get_dataset(part, add_sources=(data.uttid_source,))
    stream = data.get_stream(part, batches=False, shuffle=False,
                             add_sources=(data.uttid_source,))
    it = stream.get_epoch_iterator()
    if decode_only is not None:
        # NOTE(security): `eval` executes arbitrary code. This is only
        # acceptable while `decode_only` comes from a trusted operator.
        decode_only = eval(decode_only)

    weights = tensor.matrix('weights')
    weight_statistics = theano.function(
        [weights],
        [weights_std(weights.dimshuffle(0, 'x', 1)),
         monotonicity_penalty(weights.dimshuffle(0, 'x', 1))])

    num_examples = .0
    total_nll = .0
    total_errors = .0
    total_length = .0
    total_wer_errors = .0
    total_word_length = 0.
    with open(os.path.expandvars(config['vocabulary'])) as f:
        vocabulary = dict(line.split() for line in f)

    def to_words(chars):
        # Map every whitespace-separated word through the vocabulary,
        # falling back to the '<UNK>' entry for unknown words.
        return [vocabulary[word] if word in vocabulary
                else vocabulary['<UNK>'] for word in chars.split()]

    def average(total, count):
        # Running average that yields NaN instead of raising while the
        # accumulated reference length is still zero.
        return total / count if count else float('nan')

    print_to = sys.stdout
    decoded_file = None
    try:
        if report:
            alignments_path = os.path.join(report, "alignments")
            # Create the alignments directory even when the report
            # directory itself already exists.
            if not os.path.isdir(alignments_path):
                os.makedirs(alignments_path)
            print_to = open(os.path.join(report, "report.txt"), 'w')
        if decoded_save:
            decoded_file = open(decoded_save, 'w')

        # Each example is indexed positionally below: [0] is fed to the
        # recognizer, [1] is the label sequence, [2] the utterance id.
        # The loop variable is deliberately not called `data`, so the
        # `Data` object above is not shadowed.
        for number, example in enumerate(it):
            if decode_only and number not in decode_only:
                continue
            print("Utterance {} ({})".format(number, example[2]),
                  file=print_to)
            groundtruth = dataset.decode(example[1])
            groundtruth_text = dataset.pretty_print(example[1])
            costs_groundtruth, weights_groundtruth = (
                recognizer.analyze(example[0], example[1])[:2])
            weight_std_groundtruth, mono_penalty_groundtruth = (
                weight_statistics(weights_groundtruth))
            total_nll += costs_groundtruth.sum()
            num_examples += 1
            print("Groundtruth:", groundtruth_text, file=print_to)
            print("Groundtruth cost:", costs_groundtruth.sum(), file=print_to)
            print("Groundtruth weight std:", weight_std_groundtruth,
                  file=print_to)
            print("Groundtruth monotonicity penalty:",
                  mono_penalty_groundtruth, file=print_to)
            print("Average groundtruth cost: {}".format(
                total_nll / num_examples), file=print_to)
            if nll_only:
                continue

            before = time.time()
            outputs, search_costs = recognizer.beam_search(
                example[0], char_discount=char_discount)
            took = time.time() - before
            recognized = dataset.decode(outputs[0])
            recognized_text = dataset.pretty_print(outputs[0])
            if recognized:
                costs_recognized, weights_recognized = (
                    recognizer.analyze(example[0], outputs[0])[:2])
                weight_std_recognized, mono_penalty_recognized = (
                    weight_statistics(weights_recognized))
                error = min(1, wer(groundtruth, recognized))
            else:
                # An empty hypothesis is not analyzed (presumably the
                # underlying scan cannot handle zero-length sequences);
                # it is counted as a complete error.
                error = 1
            total_errors += len(groundtruth) * error
            total_length += len(groundtruth)

            groundtruth_words = to_words(groundtruth_text)
            wer_error = min(1, wer(groundtruth_words,
                                   to_words(recognized_text)))
            # Weight by the number of reference *words*; weighting by the
            # character count skews the running average WER.
            total_wer_errors += len(groundtruth_words) * wer_error
            total_word_length += len(groundtruth_words)

            if report and recognized:
                show_alignment(weights_groundtruth, groundtruth,
                               bos_symbol=True)
                pyplot.savefig(os.path.join(
                    alignments_path, "{}.groundtruth.png".format(number)))
                show_alignment(weights_recognized, recognized,
                               bos_symbol=True)
                pyplot.savefig(os.path.join(
                    alignments_path, "{}.recognized.png".format(number)))

            if decoded_file is not None:
                print("{} {}".format(example[2], ' '.join(recognized)),
                      file=decoded_file)

            print("Decoding took:", took, file=print_to)
            print("Beam search cost:", search_costs[0], file=print_to)
            print("Recognized:", recognized_text, file=print_to)
            if recognized:
                print("Recognized cost:", costs_recognized.sum(),
                      file=print_to)
                print("Recognized weight std:", weight_std_recognized,
                      file=print_to)
                print("Recognized monotonicity penalty:",
                      mono_penalty_recognized, file=print_to)
            print("CER:", error, file=print_to)
            print("Average CER:", average(total_errors, total_length),
                  file=print_to)
            print("WER:", wer_error, file=print_to)
            print("Average WER:",
                  average(total_wer_errors, total_word_length),
                  file=print_to)
    finally:
        # Close whatever this function opened, even if decoding failed;
        # standard output is left alone.
        if print_to is not sys.stdout:
            print_to.close()
        if decoded_file is not None:
            decoded_file.close()
Esempio n. 4
0
def train(config, save_path, bokeh_name,
          params, bokeh_server, test_tag, use_load_ext,
          load_log, fast_start, validation_epochs, validation_batches,
          per_epochs, per_batches):
    """Build the recognizer and its training algorithm, then run training.

    Parameters
    ----------
    config : dict
        Experiment configuration. The 'data', 'net', 'initialization',
        'regularization' and 'training' sections are read here.
    save_path : str
        Checkpoint path. "_best" and "_best_ll" variants are derived
        from it by inserting the suffix before the file extension.
    bokeh_name : str
        Name of the live plot; the basename of `save_path` is used when
        it is empty.
    params : str
        Path of previously saved parameters. When non-empty they are
        loaded into the recognizer and the path is handed to the
        `Load` / `LoadLog` extensions (if those are requested).
    bokeh_server : str
        Passed to `Plot` as `server_url`.
    test_tag : bool
        When true, Theano test values are attached to the input
        variables using the first training batch.
    use_load_ext : bool
        Add the `Load` extension (only when `params` is given).
    load_log : bool
        Add the `LoadLog` extension (only when `params` is given).
    fast_start : bool
        When true, validation and checkpointing before the first epoch
        are skipped.
    validation_epochs, validation_batches
        How often the validation likelihood is monitored.
    per_epochs, per_batches
        How often the phoneme error rate is monitored.
    """
    root_path, extension = os.path.splitext(save_path)

    data = Data(**config['data'])

    # Build the main brick and initialize all parameters.
    recognizer = SpeechRecognizer(
        data.recordings_source, data.labels_source,
        data.eos_label,
        data.num_features, data.num_labels,
        name="recognizer",
        data_prepend_eos=data.prepend_eos,
        character_map=data.character_map,
        **config["net"])
    # Deepest brick paths (most '/') are configured first. The key takes
    # the whole (path, attributes) item because tuple parameters in a
    # lambda are a syntax error on Python 3.
    for brick_path, attribute_dict in sorted(
            config['initialization'].items(),
            key=lambda item: -item[0].count('/')):
        for attribute, value in attribute_dict.items():
            brick, = Selector(recognizer).select(brick_path).bricks
            setattr(brick, attribute, value)
            brick.push_initialization_config()
    recognizer.initialize()

    # Separate attention_params to be handled differently
    # when regularization is applied
    attention = recognizer.generator.transition.attention
    attention_params = Selector(attention).get_parameters().values()

    logger.info(
        "Initialization schemes for all bricks.\n"
        "Works well only in my branch with __repr__ added to all them,\n"
        "there is an issue #463 in Blocks to do that properly.")

    def show_init_scheme(cur):
        # Recursively collect every `*_init` attribute of a brick and of
        # its children into a nested dict.
        result = dict()
        for attr in dir(cur):
            if attr.endswith('_init'):
                result[attr] = getattr(cur, attr)
        for child in cur.children:
            result[child.name] = show_init_scheme(child)
        return result
    logger.info(pprint.pformat(show_init_scheme(recognizer)))

    if params:
        logger.info("Load parameters from " + params)
        recognizer.load_params(params)

    if test_tag:
        tensor.TensorVariable.__str__ = tensor.TensorVariable.__repr__
        __stream = data.get_stream("train")
        __data = next(__stream.get_epoch_iterator(as_dict=True))
        recognizer.recordings.tag.test_value = __data[data.recordings_source]
        recognizer.recordings_mask.tag.test_value = __data[data.recordings_source + '_mask']
        recognizer.labels.tag.test_value = __data[data.labels_source]
        recognizer.labels_mask.tag.test_value = __data[data.labels_source + '_mask']
        theano.config.compute_test_value = 'warn'

    batch_cost = recognizer.get_cost_graph().sum()
    batch_size = named_copy(recognizer.recordings.shape[1], "batch_size")
    # Assumes constant batch size. `aggregation.mean` is not used because
    # of Blocks #514.
    cost = batch_cost / batch_size
    cost.name = "sequence_log_likelihood"
    logger.info("Cost graph is built")

    # Fetch variables useful for debugging.
    # It is important not to use any aggregation schemes here,
    # as it's currently impossible to spread the effect of
    # regularization on their variables, see Blocks #514.
    cost_cg = ComputationGraph(cost)
    r = recognizer
    energies, = VariableFilter(
        applications=[r.generator.readout.readout], name="output_0")(
                cost_cg)
    bottom_output, = VariableFilter(
        applications=[r.bottom.apply], name="output")(
                cost_cg)
    attended, = VariableFilter(
        applications=[r.generator.transition.apply], name="attended")(
                cost_cg)
    attended_mask, = VariableFilter(
        applications=[r.generator.transition.apply], name="attended_mask")(
                cost_cg)
    weights, = VariableFilter(
        applications=[r.generator.evaluate], name="weights")(
                cost_cg)
    max_recording_length = named_copy(r.recordings.shape[0],
                                      "max_recording_length")
    # To exclude subsampling related bugs
    max_attended_mask_length = named_copy(attended_mask.shape[0],
                                          "max_attended_mask_length")
    max_attended_length = named_copy(attended.shape[0],
                                     "max_attended_length")
    max_num_phonemes = named_copy(r.labels.shape[0],
                                  "max_num_phonemes")
    min_energy = named_copy(energies.min(), "min_energy")
    max_energy = named_copy(energies.max(), "max_energy")
    mean_attended = named_copy(abs(attended).mean(),
                               "mean_attended")
    mean_bottom_output = named_copy(abs(bottom_output).mean(),
                                    "mean_bottom_output")
    weights_penalty = named_copy(monotonicity_penalty(weights, r.labels_mask),
                                 "weights_penalty")
    weights_entropy = named_copy(entropy(weights, r.labels_mask),
                                 "weights_entropy")
    mask_density = named_copy(r.labels_mask.mean(),
                              "mask_density")
    cg = ComputationGraph([
        cost, weights_penalty, weights_entropy,
        min_energy, max_energy,
        mean_attended, mean_bottom_output,
        batch_size, max_num_phonemes,
        mask_density])

    # Regularization. It is applied explicitly to all variables
    # of interest, it could not be applied to the cost only as it
    # would not have effect on auxiliary variables, see Blocks #514.
    reg_config = config['regularization']
    regularized_cg = cg
    if reg_config.get('dropout'):
        logger.info('apply dropout')
        regularized_cg = apply_dropout(cg, [bottom_output], 0.5)
    if reg_config.get('noise'):
        logger.info('apply noise')
        noise_subjects = [p for p in cg.parameters if p not in attention_params]
        # Noise is stacked on top of whatever regularization was applied
        # above; starting again from `cg` would silently drop dropout
        # when both options are enabled.
        regularized_cg = apply_noise(
            regularized_cg, noise_subjects, reg_config['noise'])
    regularized_cost = regularized_cg.outputs[0]
    regularized_weights_penalty = regularized_cg.outputs[1]

    # The model extracts all the parameters from the computation graph
    # and gives them hierarchical names. This helps to notice when,
    # because of some bug, a parameter is not in the computation graph.
    model = SpeechModel(regularized_cost)
    # Kept under a separate name: `params` is the path argument and is
    # still needed below by the Load / LoadLog extensions.
    parameters = model.get_parameter_dict()
    logger.info("Parameters:\n" +
                pprint.pformat(
                    [(key, parameters[key].get_value().shape) for key
                        in sorted(parameters.keys())],
                    width=120))

    # Define the training algorithm.
    train_conf = config['training']
    clipping = StepClipping(train_conf['gradient_threshold'])
    clipping.threshold.name = "gradient_norm_threshold"
    rule_names = train_conf.get('rules', ['momentum'])
    core_rules = []
    if 'momentum' in rule_names:
        logger.info("Using scaling and momentum for training")
        core_rules.append(Momentum(train_conf['scale'], train_conf['momentum']))
    if 'adadelta' in rule_names:
        logger.info("Using AdaDelta for training")
        core_rules.append(AdaDelta(train_conf['decay_rate'], train_conf['epsilon']))
    max_norm_rules = []
    if reg_config.get('max_norm', False):
        logger.info("Apply MaxNorm")
        maxnorm_subjects = VariableFilter(roles=[WEIGHT])(cg.parameters)
        if reg_config.get('max_norm_exclude_lookup', False):
            maxnorm_subjects = [v for v in maxnorm_subjects
                                if not isinstance(get_brick(v), LookupTable)]
        logger.info("Parameters covered by MaxNorm:\n"
                    + pprint.pformat([name for name, p in parameters.items()
                                        if p in maxnorm_subjects]))
        logger.info("Parameters NOT covered by MaxNorm:\n"
                    + pprint.pformat([name for name, p in parameters.items()
                                        if p not in maxnorm_subjects]))
        max_norm_rules = [
            Restrict(VariableClipping(reg_config['max_norm'], axis=0),
                        maxnorm_subjects)]
    algorithm = GradientDescent(
        cost=regularized_cost +
            reg_config.get("penalty_coof", .0) * regularized_weights_penalty / batch_size +
            reg_config.get("decay", .0) *
            l2_norm(VariableFilter(roles=[WEIGHT])(cg.parameters)) ** 2,
        parameters=parameters.values(),
        step_rule=CompositeRule(
            [clipping] + core_rules + max_norm_rules +
            # Parameters are not changed at all
            # when nans are encountered.
            [RemoveNotFinite(0.0)]))

    # More variables for debugging: some of them can be added only
    # after the `algorithm` object is created.
    # Work on a copy so the graph's own `outputs` list is not mutated.
    observables = list(regularized_cg.outputs)
    observables += [
        algorithm.total_step_norm, algorithm.total_gradient_norm,
        clipping.threshold]
    for name, param in parameters.items():
        num_elements = numpy.prod(param.get_value().shape)
        norm = param.norm(2) / num_elements ** 0.5
        grad_norm = algorithm.gradients[param].norm(2) / num_elements ** 0.5
        step_norm = algorithm.steps[param].norm(2) / num_elements ** 0.5
        stats = tensor.stack(norm, grad_norm, step_norm, step_norm / grad_norm)
        stats.name = name + '_stats'
        observables.append(stats)

    def attach_aggregation_schemes(variables):
        # Aggregation specification has to be factored out as a separate
        # function as it has to be applied at the very last stage
        # separately to training and validation observables.
        result = []
        for var in variables:
            if var.name == 'weights_penalty':
                result.append(named_copy(aggregation.mean(var, batch_size),
                                            'weights_penalty_per_recording'))
            elif var.name == 'weights_entropy':
                result.append(named_copy(aggregation.mean(
                    var, recognizer.labels_mask.sum()), 'weights_entropy_per_label'))
            else:
                result.append(var)
        return result

    # Build main loop.
    logger.info("Initialize extensions")
    extensions = []
    if use_load_ext and params:
        extensions.append(Load(params, load_iteration_state=True, load_log=True))
    if load_log and params:
        extensions.append(LoadLog(params))
    extensions += [
        Timing(after_batch=True),
        CGStatistics(),
        #CodeVersion(['lvsr']),
        ]
    extensions.append(TrainingDataMonitoring(
        [observables[0], algorithm.total_gradient_norm,
            algorithm.total_step_norm, clipping.threshold,
            max_recording_length,
            max_attended_length, max_attended_mask_length], after_batch=True))
    average_monitoring = TrainingDataMonitoring(
        attach_aggregation_schemes(observables),
        prefix="average", every_n_batches=10)
    extensions.append(average_monitoring)
    validation = DataStreamMonitoring(
        attach_aggregation_schemes([cost, weights_entropy, weights_penalty]),
        data.get_stream("valid"), prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=validation_epochs,
            every_n_batches=validation_batches,
            after_training=False)
    extensions.append(validation)
    recognizer.init_beam_search(10)
    per = PhonemeErrorRate(recognizer, data.get_dataset("valid"))
    per_monitoring = DataStreamMonitoring(
        [per], data.get_stream("valid", batches=False, shuffle=False),
        prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=per_epochs,
            every_n_batches=per_batches,
            after_training=False)
    extensions.append(per_monitoring)
    track_the_best_per = TrackTheBest(
        per_monitoring.record_name(per)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    track_the_best_likelihood = TrackTheBest(
        validation.record_name(cost)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    extensions += [track_the_best_likelihood, track_the_best_per]
    extensions.append(AdaptiveClipping(
        algorithm.total_gradient_norm.name,
        clipping, train_conf['gradient_threshold'],
        decay_rate=0.998, burnin_period=500))
    extensions += [
        SwitchOffLengthFilter(data.length_filter,
            after_n_batches=train_conf.get('stop_filtering')),
        FinishAfter(after_n_batches=train_conf['num_batches'],
                    after_n_epochs=train_conf['num_epochs'])
        .add_condition(["after_batch"], _gradient_norm_is_none),
        # Live plotting: requires launching `bokeh-server`
        # and allows to see what happens online.
        Plot(bokeh_name
             if bokeh_name
             else os.path.basename(save_path),
             [# Plot 1: training and validation costs
             [average_monitoring.record_name(regularized_cost),
             validation.record_name(cost)],
             # Plot 2: gradient norm,
             [average_monitoring.record_name(algorithm.total_gradient_norm),
             average_monitoring.record_name(clipping.threshold)],
             # Plot 3: phoneme error rate
             [per_monitoring.record_name(per)],
             # Plot 4: training and validation mean weight entropy
             [average_monitoring._record_name('weights_entropy_per_label'),
             validation._record_name('weights_entropy_per_label')],
             # Plot 5: training and validation monotonicity penalty
             [average_monitoring._record_name('weights_penalty_per_recording'),
             validation._record_name('weights_penalty_per_recording')]],
             every_n_batches=10,
             server_url=bokeh_server),
        Checkpoint(save_path,
                   before_first_epoch=not fast_start, after_epoch=True,
                   every_n_batches=train_conf.get('save_every_n_batches'),
                   save_separately=["model", "log"],
                   use_cpickle=True)
        .add_condition(
            ['after_epoch'],
            OnLogRecord(track_the_best_per.notification_name),
            (root_path + "_best" + extension,))
        .add_condition(
            ['after_epoch'],
            OnLogRecord(track_the_best_likelihood.notification_name),
            (root_path + "_best_ll" + extension,)),
        ProgressBar(),
        Printing(every_n_batches=1,
                    attribute_filter=PrintingFilterList()
                    )]

    # Save the config into the status
    log = TrainingLog()
    log.status['_config'] = repr(config)
    main_loop = MainLoop(
        model=model, log=log, algorithm=algorithm,
        data_stream=data.get_stream("train"),
        extensions=extensions)
    main_loop.run()
Esempio n. 5
0
def search(config, params, load_path, part, decode_only, report, decoded_save,
           nll_only, seed):
    """Decode utterances of one data part with beam search and report errors.

    For every utterance the groundtruth cost and attention statistics are
    printed; unless `nll_only` is set, the utterance is also decoded and
    the per-utterance and running CER (and WER when a vocabulary is
    configured) are printed.

    Parameters
    ----------
    config : dict
        Experiment configuration; 'data', 'net', 'monitoring'/'search'
        and the optional 'vocabulary' entries are read.
    params
        Not used by this function.
    load_path : str
        Path the recognizer parameters are loaded from.
    part : str
        Name of the data part to decode. The 'train' part is shuffled
        and limited to 500 examples.
    decode_only : str or None
        Python expression evaluated to the collection of utterance
        numbers to process; all utterances are processed when None.
    report : str
        Directory for "report.txt" and the alignment plots. Output goes
        to stdout and no plots are saved when it is empty.
    decoded_save : str
        File that receives one "<utterance id> <decoded labels>" line
        per decoded utterance; nothing is written when it is empty.
    nll_only : bool
        Only compute groundtruth costs and skip decoding.
    seed
        Passed to `data.get_stream`.
    """
    import matplotlib
    matplotlib.use("Agg")
    from matplotlib import pyplot
    from lvsr.notebook import show_alignment

    data = Data(**config['data'])
    search_conf = config['monitoring']['search']

    logger.info("Recognizer initialization started")
    recognizer = SpeechRecognizer(data.recordings_source,
                                  data.labels_source,
                                  data.eos_label,
                                  data.num_features,
                                  data.num_labels,
                                  character_map=data.character_map,
                                  name='recognizer',
                                  **config["net"])
    recognizer.load_params(load_path)
    recognizer.init_beam_search(search_conf['beam_size'])
    logger.info("Recognizer is initialized")

    stream = data.get_stream(part,
                             batches=False,
                             shuffle=part == 'train',
                             add_sources=(data.uttid_source, ),
                             num_examples=500 if part == 'train' else None,
                             seed=seed)
    it = stream.get_epoch_iterator()
    if decode_only is not None:
        # SECURITY: eval() executes arbitrary code. `decode_only` must only
        # ever come from a trusted source (e.g. the operator's command line).
        decode_only = eval(decode_only)

    weights = tensor.matrix('weights')
    weight_statistics = theano.function([weights], [
        weights_std(weights.dimshuffle(0, 'x', 1)),
        monotonicity_penalty(weights.dimshuffle(0, 'x', 1))
    ])

    use_vocabulary = bool(config.get('vocabulary'))

    num_examples = .0
    total_nll = .0
    total_errors = .0
    total_length = .0
    total_wer_errors = .0
    total_word_length = 0.

    print_to = sys.stdout
    decoded_file = None
    try:
        if report:
            alignments_path = os.path.join(report, "alignments")
            # Create each directory on its own: the report directory may
            # already exist without the alignments sub-directory.
            for directory in (report, alignments_path):
                if not os.path.exists(directory):
                    os.mkdir(directory)
            print_to = open(os.path.join(report, "report.txt"), 'w')

        if decoded_save:
            decoded_file = open(decoded_save, 'w')

        if use_vocabulary:
            with open(os.path.expandvars(config['vocabulary'])) as f:
                vocabulary = dict(line.split() for line in f)

            def to_words(chars):
                # Map every whitespace-separated token through the
                # vocabulary, falling back to the '<UNK>' entry.
                return [
                    vocabulary[word] if word in vocabulary
                    else vocabulary['<UNK>']
                    for word in chars.split()
                ]

        for number, example in enumerate(it):
            if decode_only and number not in decode_only:
                continue
            print("Utterance {} ({})".format(number, example[2]),
                  file=print_to)
            groundtruth = data.decode(example[1])
            groundtruth_text = data.pretty_print(example[1])
            costs_groundtruth, weights_groundtruth = (recognizer.analyze(
                example[0], example[1], example[1])[:2])
            weight_std_groundtruth, mono_penalty_groundtruth = (
                weight_statistics(weights_groundtruth))
            total_nll += costs_groundtruth.sum()
            num_examples += 1
            print("Groundtruth:", groundtruth_text, file=print_to)
            print("Groundtruth cost:", costs_groundtruth.sum(), file=print_to)
            print("Groundtruth weight std:", weight_std_groundtruth,
                  file=print_to)
            print("Groundtruth monotonicity penalty:",
                  mono_penalty_groundtruth,
                  file=print_to)
            print("Average groundtruth cost: {}".format(
                      total_nll / num_examples),
                  file=print_to)
            if nll_only:
                print_to.flush()
                continue

            before = time.time()
            outputs, search_costs = recognizer.beam_search(
                example[0],
                char_discount=search_conf['char_discount'],
                round_to_inf=search_conf['round_to_inf'],
                stop_on=search_conf['stop_on'])
            took = time.time() - before
            recognized = data.decode(outputs[0])
            recognized_text = data.pretty_print(outputs[0])
            if recognized:
                # Theano scan doesn't work with 0 length sequences
                costs_recognized, weights_recognized = (recognizer.analyze(
                    example[0], example[1], outputs[0])[:2])
                weight_std_recognized, mono_penalty_recognized = (
                    weight_statistics(weights_recognized))
                error = min(1, wer(groundtruth, recognized))
            else:
                error = 1
            total_errors += len(groundtruth) * error
            total_length += len(groundtruth)

            if use_vocabulary:
                wer_error = min(
                    1, wer(to_words(groundtruth_text),
                           to_words(recognized_text)))
                # NOTE(review): the running WER is weighted by
                # len(groundtruth) (the label sequence), not by the number
                # of words -- kept as in the original, confirm intended.
                total_wer_errors += len(groundtruth) * wer_error
                total_word_length += len(groundtruth)

            if report and recognized:
                show_alignment(weights_groundtruth, groundtruth,
                               bos_symbol=True)
                pyplot.savefig(
                    os.path.join(alignments_path,
                                 "{}.groundtruth.png".format(number)))
                show_alignment(weights_recognized, recognized,
                               bos_symbol=True)
                pyplot.savefig(
                    os.path.join(alignments_path,
                                 "{}.recognized.png".format(number)))

            if decoded_file is not None:
                print("{} {}".format(example[2], ' '.join(recognized)),
                      file=decoded_file)

            print("Decoding took:", took, file=print_to)
            print("Beam search cost:", search_costs[0], file=print_to)
            print("Recognized:", recognized_text, file=print_to)
            if recognized:
                print("Recognized cost:", costs_recognized.sum(),
                      file=print_to)
                print("Recognized weight std:",
                      weight_std_recognized,
                      file=print_to)
                print("Recognized monotonicity penalty:",
                      mono_penalty_recognized,
                      file=print_to)
            print("CER:", error, file=print_to)
            print("Average CER:", total_errors / total_length, file=print_to)
            if use_vocabulary:
                print("WER:", wer_error, file=print_to)
                print("Average WER:",
                      total_wer_errors / total_word_length,
                      file=print_to)
            print_to.flush()
    finally:
        # Release the output files even when decoding fails half-way;
        # stdout is never closed.
        if decoded_file is not None:
            decoded_file.close()
        if print_to is not sys.stdout:
            print_to.close()
Esempio n. 6
0
def search(config, params, load_path, part, decode_only, report,
           decoded_save, nll_only, seed):
    """Decode utterances of one data part with beam search and report errors.

    For every utterance the groundtruth cost and attention statistics are
    printed; unless `nll_only` is set, the utterance is also decoded and
    the per-utterance and running CER (and WER when a vocabulary is
    configured) are printed. A failed beam search
    (`CandidateNotFoundError`) is logged and counted as an empty
    recognition.

    Parameters
    ----------
    config : dict
        Experiment configuration; 'data', 'monitoring'/'search' and the
        optional 'vocabulary' entries are read, and the whole dict is
        passed on to `create_model`.
    params
        Not used by this function.
    load_path : str
        Passed to `create_model` together with `config` and the data.
    part : str
        Name of the data part to decode. The 'train' part is shuffled
        and limited to 500 examples.
    decode_only : str or None
        Python expression evaluated to the collection of utterance
        numbers to process; all utterances are processed when None.
    report : str
        Directory for "report.txt" and the alignment plots. Output goes
        to stdout and no plots are saved when it is empty.
    decoded_save : str
        File that receives one "<utterance id> <decoded labels>" line
        per decoded utterance; nothing is written when it is empty.
    nll_only : bool
        Only compute groundtruth costs and skip decoding.
    seed
        Passed to `data.get_stream`.
    """
    import matplotlib
    matplotlib.use("Agg")
    from matplotlib import pyplot
    from lvsr.notebook import show_alignment

    data = Data(**config['data'])
    search_conf = config['monitoring']['search']

    logger.info("Recognizer initialization started")
    recognizer = create_model(config, data, load_path)
    recognizer.init_beam_search(search_conf['beam_size'])
    logger.info("Recognizer is initialized")

    has_uttids = 'uttids' in data.info_dataset.provides_sources
    add_sources = ('uttids',) if has_uttids else ()
    dataset = data.get_dataset(part, add_sources)
    stream = data.get_stream(part, batches=False,
                             shuffle=part == 'train',
                             add_sources=add_sources,
                             num_examples=500 if part == 'train' else None,
                             seed=seed)
    it = stream.get_epoch_iterator(as_dict=True)
    if decode_only is not None:
        # SECURITY: eval() executes arbitrary code. `decode_only` must only
        # ever come from a trusted source (e.g. the operator's command line).
        decode_only = eval(decode_only)

    weights = tensor.matrix('weights')
    weight_statistics = theano.function(
        [weights],
        [weights_std(weights.dimshuffle(0, 'x', 1)),
            monotonicity_penalty(weights.dimshuffle(0, 'x', 1))])

    use_vocabulary = bool(config.get('vocabulary'))

    num_examples = .0
    total_nll = .0
    total_errors = .0
    total_length = .0
    total_wer_errors = .0
    total_word_length = 0.

    print_to = sys.stdout
    decoded_file = None
    try:
        if report:
            alignments_path = os.path.join(report, "alignments")
            # Create each directory on its own: the report directory may
            # already exist without the alignments sub-directory.
            for directory in (report, alignments_path):
                if not os.path.exists(directory):
                    os.mkdir(directory)
            print_to = open(os.path.join(report, "report.txt"), 'w')

        if decoded_save:
            decoded_file = open(decoded_save, 'w')

        if use_vocabulary:
            with open(os.path.expandvars(config['vocabulary'])) as f:
                vocabulary = dict(line.split() for line in f)

            def to_words(chars):
                # Map every whitespace-separated token through the
                # vocabulary, falling back to the '<UNK>' entry.
                return [vocabulary[word] if word in vocabulary
                        else vocabulary['<UNK>'] for word in chars.split()]

        for number, example in enumerate(it):
            if decode_only and number not in decode_only:
                continue
            uttids = example.pop('uttids', None)
            raw_groundtruth = example.pop('labels')
            required_inputs = dict_subset(example, recognizer.inputs.keys())

            print("Utterance {} ({})".format(number, uttids), file=print_to)

            groundtruth = dataset.decode(raw_groundtruth)
            groundtruth_text = dataset.pretty_print(raw_groundtruth, example)
            costs_groundtruth, weights_groundtruth = recognizer.analyze(
                inputs=required_inputs,
                groundtruth=raw_groundtruth,
                prediction=raw_groundtruth)[:2]
            weight_std_groundtruth, mono_penalty_groundtruth = (
                weight_statistics(weights_groundtruth))
            total_nll += costs_groundtruth.sum()
            num_examples += 1
            print("Groundtruth:", groundtruth_text, file=print_to)
            print("Groundtruth cost:", costs_groundtruth.sum(), file=print_to)
            print("Groundtruth weight std:", weight_std_groundtruth,
                  file=print_to)
            print("Groundtruth monotonicity penalty:",
                  mono_penalty_groundtruth, file=print_to)
            print("Average groundtruth cost: {}".format(
                      total_nll / num_examples),
                  file=print_to)
            if nll_only:
                print_to.flush()
                continue

            before = time.time()
            try:
                search_kwargs = dict(
                    char_discount=search_conf.get('char_discount'),
                    round_to_inf=search_conf.get('round_to_inf'),
                    stop_on=search_conf.get('stop_on'),
                    validate_solution_function=getattr(
                        data.info_dataset, 'validate_solution', None))
                # NOTE(review): this drops every falsy value, so an explicit
                # 0 / False / '' in the config is not forwarded to
                # beam_search -- confirm that is intended.
                search_kwargs = {k: v for k, v in search_kwargs.items() if v}
                outputs, search_costs = recognizer.beam_search(
                    required_inputs, **search_kwargs)
            except CandidateNotFoundError:
                logger.error('Candidate not found!')
                outputs = [[]]
                # `numpy.nan` is the spelling available in every NumPy
                # release (the `NaN` alias was removed in NumPy 2.0).
                search_costs = [[numpy.nan]]

            took = time.time() - before
            recognized = dataset.decode(outputs[0])
            recognized_text = dataset.pretty_print(outputs[0], example)
            if recognized:
                # Theano scan doesn't work with 0 length sequences
                costs_recognized, weights_recognized = recognizer.analyze(
                    inputs=required_inputs,
                    groundtruth=raw_groundtruth,
                    prediction=outputs[0])[:2]
                weight_std_recognized, mono_penalty_recognized = (
                    weight_statistics(weights_recognized))
                error = min(1, wer(groundtruth, recognized))
            else:
                error = 1
            total_errors += len(groundtruth) * error
            total_length += len(groundtruth)

            if use_vocabulary:
                wer_error = min(1, wer(to_words(groundtruth_text),
                                       to_words(recognized_text)))
                # NOTE(review): the running WER is weighted by
                # len(groundtruth) (the label sequence), not by the number
                # of words -- kept as in the original, confirm intended.
                total_wer_errors += len(groundtruth) * wer_error
                total_word_length += len(groundtruth)

            if report and recognized:
                show_alignment(weights_groundtruth, groundtruth,
                               bos_symbol=True)
                pyplot.savefig(os.path.join(
                    alignments_path, "{}.groundtruth.png".format(number)))
                show_alignment(weights_recognized, recognized,
                               bos_symbol=True)
                pyplot.savefig(os.path.join(
                    alignments_path, "{}.recognized.png".format(number)))

            if decoded_file is not None:
                print("{} {}".format(uttids, ' '.join(recognized)),
                      file=decoded_file)

            print("Decoding took:", took, file=print_to)
            print("Beam search cost:", search_costs[0], file=print_to)
            print("Recognized:", recognized_text, file=print_to)
            if recognized:
                print("Recognized cost:", costs_recognized.sum(),
                      file=print_to)
                print("Recognized weight std:", weight_std_recognized,
                      file=print_to)
                print("Recognized monotonicity penalty:",
                      mono_penalty_recognized, file=print_to)
            print("CER:", error, file=print_to)
            print("Average CER:", total_errors / total_length, file=print_to)
            if use_vocabulary:
                print("WER:", wer_error, file=print_to)
                print("Average WER:", total_wer_errors / total_word_length,
                      file=print_to)
            print_to.flush()
    finally:
        # Release the output files even when decoding fails half-way;
        # stdout is never closed.
        if decoded_file is not None:
            decoded_file.close()
        if print_to is not sys.stdout:
            print_to.close()
Esempio n. 7
0
def initialize_all(config, save_path, bokeh_name,
                   params, bokeh_server, bokeh, test_tag, use_load_ext,
                   load_log, fast_start):
    """Build the model, the training algorithm and the main loop extensions.

    Parameters
    ----------
    config : dict-like
        Experiment configuration. The sections ``'data'``, ``'training'``,
        ``'monitoring'`` and ``'net'`` are read here; ``'regularization'``
        is optional.
    save_path : str
        Checkpoint path. The paths for the best-PER and best-cost
        checkpoints are derived from it by inserting ``_best`` and
        ``_best_ll`` before the extension.
    bokeh_name : str
        Name of the plot; when falsy the basename of `save_path` is used.
    params : str
        Path to a file with saved parameters. When truthy, the parameter
        values are loaded into the model; it is also what the `Load` and
        `LoadLog` extensions are given.
    bokeh_server : str
        Passed to `Plot` as ``server_url``.
    bokeh : bool
        Whether to add the `Plot` extension.
    test_tag
        Passed through to `create_model`.
    use_load_ext : bool
        Add the `Load` extension (requires `params`).
    load_log : bool
        Add the `LoadLog` extension (requires `params`).
    fast_start : bool
        When true, validation, search and checkpointing are not run
        before the first epoch.

    Returns
    -------
    tuple
        ``(model, algorithm, data, extensions)``.

    """
    root_path, extension = os.path.splitext(save_path)

    data = Data(**config['data'])
    train_conf = config['training']
    recognizer = create_model(config, data, test_tag)

    # Separate attention_params to be handled differently
    # when regularization is applied
    attention = recognizer.generator.transition.attention
    attention_params = Selector(attention).get_parameters().values()

    logger.info(
        "Initialization schemes for all bricks.\n"
        "Works well only in my branch with __repr__ added to all them,\n"
        "there is an issue #463 in Blocks to do that properly.")

    def show_init_scheme(cur):
        # Recursively collect every `*_init` attribute of a brick
        # and of its children into a nested dictionary.
        result = dict()
        for attr in dir(cur):
            if attr.endswith('_init'):
                result[attr] = getattr(cur, attr)
        for child in cur.children:
            result[child.name] = show_init_scheme(child)
        return result
    logger.info(pprint.pformat(show_init_scheme(recognizer)))

    prediction, prediction_mask = add_exploration(recognizer, data, train_conf)

    #
    # Observables:
    #
    primary_observables = []  # monitored each batch
    secondary_observables = []  # monitored every 10 batches
    validation_observables = []  # monitored on the validation set

    cg = recognizer.get_cost_graph(
        batch=True, prediction=prediction, prediction_mask=prediction_mask)
    labels, = VariableFilter(
        applications=[recognizer.cost], name='labels')(cg)
    labels_mask, = VariableFilter(
        applications=[recognizer.cost], name='labels_mask')(cg)

    gain_matrix = VariableFilter(
        theano_name=RewardRegressionEmitter.GAIN_MATRIX)(cg)
    if len(gain_matrix):
        gain_matrix, = gain_matrix
        primary_observables.append(
            rename(gain_matrix.min(), 'min_gain'))
        primary_observables.append(
            rename(gain_matrix.max(), 'max_gain'))

    batch_cost = cg.outputs[0].sum()
    batch_size = rename(recognizer.labels.shape[1], "batch_size")
    # Assumes constant batch size. `aggregation.mean` is not used because
    # of Blocks #514.
    cost = batch_cost / batch_size
    cost.name = "sequence_total_cost"
    logger.info("Cost graph is built")

    # Fetch variables useful for debugging.
    # It is important not to use any aggregation schemes here,
    # as it's currently impossible to spread the effect of
    # regularization on their variables, see Blocks #514.
    cost_cg = ComputationGraph(cost)
    r = recognizer
    energies, = VariableFilter(
        applications=[r.generator.readout.readout], name="output_0")(
            cost_cg)
    bottom_output = VariableFilter(
        # We need name_regex instead of name because LookupTable calls
        # its output output_0
        applications=[r.bottom.apply], name_regex="output")(
            cost_cg)[-1]
    attended, = VariableFilter(
        applications=[r.generator.transition.apply], name="attended")(
            cost_cg)
    attended_mask, = VariableFilter(
        applications=[r.generator.transition.apply], name="attended_mask")(
            cost_cg)
    weights, = VariableFilter(
        applications=[r.generator.evaluate], name="weights")(
            cost_cg)
    max_recording_length = rename(bottom_output.shape[0],
                                  "max_recording_length")
    # To exclude subsampling related bugs
    max_attended_mask_length = rename(attended_mask.shape[0],
                                      "max_attended_mask_length")
    max_attended_length = rename(attended.shape[0],
                                 "max_attended_length")
    max_num_phonemes = rename(labels.shape[0],
                              "max_num_phonemes")
    min_energy = rename(energies.min(), "min_energy")
    max_energy = rename(energies.max(), "max_energy")
    mean_attended = rename(abs(attended).mean(),
                           "mean_attended")
    mean_bottom_output = rename(abs(bottom_output).mean(),
                                "mean_bottom_output")
    weights_penalty = rename(monotonicity_penalty(weights, labels_mask),
                             "weights_penalty")
    weights_entropy = rename(entropy(weights, labels_mask),
                             "weights_entropy")
    mask_density = rename(labels_mask.mean(),
                          "mask_density")
    # NOTE: the order of the outputs matters, the code below relies on
    # `outputs[0]` being the cost and `outputs[1]` the weights penalty.
    cg = ComputationGraph([
        cost, weights_penalty, weights_entropy,
        min_energy, max_energy,
        mean_attended, mean_bottom_output,
        batch_size, max_num_phonemes,
        mask_density])
    # Regularization. It is applied explicitly to all variables
    # of interest, it could not be applied to the cost only as it
    # would not have effect on auxiliary variables, see Blocks #514.
    reg_config = config.get('regularization', dict())
    regularized_cg = cg
    if reg_config.get('dropout'):
        logger.info('apply dropout')
        regularized_cg = apply_dropout(cg, [bottom_output], 0.5)
    if reg_config.get('noise'):
        logger.info('apply noise')
        noise_subjects = [p for p in cg.parameters if p not in attention_params]
        # Noise is chained onto `regularized_cg` rather than the raw `cg`,
        # otherwise dropout would be silently discarded when both
        # regularizers are enabled.
        regularized_cg = apply_noise(
            regularized_cg, noise_subjects, reg_config['noise'])

    train_cost = regularized_cg.outputs[0]
    if reg_config.get("penalty_coof", .0) > 0:
        # big warning!!!
        # here we assume that:
        # regularized_weights_penalty = regularized_cg.outputs[1]
        train_cost = (train_cost +
                      reg_config.get("penalty_coof", .0) *
                      regularized_cg.outputs[1] / batch_size)
    if reg_config.get("decay", .0) > 0:
        train_cost = (train_cost + reg_config.get("decay", .0) *
                      l2_norm(VariableFilter(roles=[WEIGHT])(cg.parameters)) ** 2)

    train_cost = rename(train_cost, 'train_cost')

    gradients = None
    if reg_config.get('adaptive_noise'):
        logger.info('apply adaptive noise')
        if ((reg_config.get("penalty_coof", .0) > 0) or
                (reg_config.get("decay", .0) > 0)):
            logger.error('using  adaptive noise with alignment weight panalty '
                         'or weight decay is probably stupid')
        # Adaptive noise starts from the unregularized `cg` and replaces
        # both the training cost and the regularized graph built above.
        train_cost, regularized_cg, gradients, noise_brick = apply_adaptive_noise(
            cg, cg.outputs[0],
            variables=cg.parameters,
            num_examples=data.get_dataset('train').num_examples,
            parameters=Model(regularized_cg.outputs[0]).get_parameter_dict().values(),
            **reg_config.get('adaptive_noise')
        )
        train_cost.name = 'train_cost'
        adapt_noise_cg = ComputationGraph(train_cost)
        model_prior_mean = rename(
            VariableFilter(applications=[noise_brick.apply],
                           name='model_prior_mean')(adapt_noise_cg)[0],
            'model_prior_mean')
        model_cost = rename(
            VariableFilter(applications=[noise_brick.apply],
                           name='model_cost')(adapt_noise_cg)[0],
            'model_cost')
        model_prior_variance = rename(
            VariableFilter(applications=[noise_brick.apply],
                           name='model_prior_variance')(adapt_noise_cg)[0],
            'model_prior_variance')
        regularized_cg = ComputationGraph(
            [train_cost, model_cost] +
            regularized_cg.outputs +
            [model_prior_mean, model_prior_variance])
        primary_observables += [
            regularized_cg.outputs[1],  # model cost
            regularized_cg.outputs[2],  # task cost
            regularized_cg.outputs[-2],  # model prior mean
            regularized_cg.outputs[-1]]  # model prior variance

    model = Model(train_cost)
    if params:
        logger.info("Load parameters from " + params)
        # please note: we cannot use recognizer.load_params
        # as it builds a new computation graph that does not have
        # shared variables added by adaptive weight noise
        #
        # The file is opened in binary mode: the saved parameters are
        # binary data, not text.
        with open(params, 'rb') as src:
            param_values = load_parameters(src)
        model.set_parameter_values(param_values)

    parameters = model.get_parameter_dict()
    logger.info("Parameters:\n" +
                pprint.pformat(
                    [(key, parameters[key].get_value().shape) for key
                     in sorted(parameters.keys())],
                    width=120))

    # Define the training algorithm.
    clipping = StepClipping(train_conf['gradient_threshold'])
    clipping.threshold.name = "gradient_norm_threshold"
    rule_names = train_conf.get('rules', ['momentum'])
    core_rules = []
    if 'momentum' in rule_names:
        logger.info("Using scaling and momentum for training")
        core_rules.append(Momentum(train_conf['scale'], train_conf['momentum']))
    if 'adadelta' in rule_names:
        logger.info("Using AdaDelta for training")
        core_rules.append(AdaDelta(train_conf['decay_rate'], train_conf['epsilon']))
    max_norm_rules = []
    if reg_config.get('max_norm', False) > 0:
        logger.info("Apply MaxNorm")
        maxnorm_subjects = VariableFilter(roles=[WEIGHT])(cg.parameters)
        if reg_config.get('max_norm_exclude_lookup', False):
            maxnorm_subjects = [v for v in maxnorm_subjects
                                if not isinstance(get_brick(v), LookupTable)]
        logger.info("Parameters covered by MaxNorm:\n"
                    + pprint.pformat([name for name, p in parameters.items()
                                      if p in maxnorm_subjects]))
        logger.info("Parameters NOT covered by MaxNorm:\n"
                    + pprint.pformat([name for name, p in parameters.items()
                                      if p not in maxnorm_subjects]))
        max_norm_rules = [
            Restrict(VariableClipping(reg_config['max_norm'], axis=0),
                     maxnorm_subjects)]
    burn_in = []
    if train_conf.get('burn_in_steps', 0):
        burn_in.append(
            BurnIn(num_steps=train_conf['burn_in_steps']))
    algorithm = GradientDescent(
        cost=train_cost,
        parameters=parameters.values(),
        gradients=gradients,
        step_rule=CompositeRule(
            [clipping] + core_rules + max_norm_rules +
            # Parameters are not changed at all
            # when nans are encountered.
            [RemoveNotFinite(0.0)] + burn_in),
        on_unused_sources='warn')

    logger.debug("Scan Ops in the gradients")
    gradient_cg = ComputationGraph(algorithm.gradients.values())
    for op in ComputationGraph(gradient_cg).scans:
        logger.debug(op)

    # More variables for debugging: some of them can be added only
    # after the `algorithm` object is created.
    secondary_observables += list(regularized_cg.outputs)
    if 'train_cost' not in [v.name for v in secondary_observables]:
        secondary_observables += [train_cost]
    secondary_observables += [
        algorithm.total_step_norm, algorithm.total_gradient_norm,
        clipping.threshold]
    for name, param in parameters.items():
        # `numpy.prod` is used instead of the deprecated alias
        # `numpy.product`, which is no longer available in NumPy 2.
        num_elements = numpy.prod(param.get_value().shape)
        # Norms are divided by sqrt(num_elements) to make the statistics
        # comparable between parameters of different sizes.
        norm = param.norm(2) / num_elements ** 0.5
        grad_norm = algorithm.gradients[param].norm(2) / num_elements ** 0.5
        step_norm = algorithm.steps[param].norm(2) / num_elements ** 0.5
        stats = tensor.stack(norm, grad_norm, step_norm, step_norm / grad_norm)
        stats.name = name + '_stats'
        secondary_observables.append(stats)

    primary_observables += [
        train_cost,
        algorithm.total_gradient_norm,
        algorithm.total_step_norm, clipping.threshold,
        max_recording_length,
        max_attended_length, max_attended_mask_length]

    validation_observables += [
        rename(aggregation.mean(batch_cost, batch_size), cost.name),
        rename(aggregation.sum_(batch_size), 'num_utterances'),
        weights_entropy, weights_penalty]

    def attach_aggregation_schemes(variables):
        # Aggregation specification has to be factored out as a separate
        # function as it has to be applied at the very last stage
        # separately to training and validation observables.
        result = []
        for var in variables:
            if var.name == 'weights_penalty':
                result.append(rename(aggregation.mean(var, batch_size),
                                     'weights_penalty_per_recording'))
            elif var.name == 'weights_entropy':
                result.append(rename(aggregation.mean(var, labels_mask.sum()),
                                     'weights_entropy_per_label'))
            else:
                result.append(var)
        return result

    mon_conf = config['monitoring']

    # Build main loop.
    logger.info("Initialize extensions")
    extensions = []
    if use_load_ext and params:
        extensions.append(Load(params, load_iteration_state=True, load_log=True))
    if load_log and params:
        extensions.append(LoadLog(params))
    extensions += [
        Timing(after_batch=True),
        CGStatistics(),
        #CodeVersion(['lvsr']),
    ]
    extensions.append(TrainingDataMonitoring(
        primary_observables, after_batch=True))
    average_monitoring = TrainingDataMonitoring(
        attach_aggregation_schemes(secondary_observables),
        prefix="average", every_n_batches=10)
    extensions.append(average_monitoring)
    validation = DataStreamMonitoring(
        attach_aggregation_schemes(validation_observables),
        data.get_stream("valid", shuffle=False), prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=mon_conf['validate_every_epochs'],
            every_n_batches=mon_conf['validate_every_batches'],
            after_training=False)
    extensions.append(validation)
    per = PhonemeErrorRate(recognizer, data,
                           **config['monitoring']['search'])
    per_monitoring = DataStreamMonitoring(
        [per], data.get_stream("valid", batches=False, shuffle=False),
        prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=mon_conf['search_every_epochs'],
            every_n_batches=mon_conf['search_every_batches'],
            after_training=False)
    extensions.append(per_monitoring)
    track_the_best_per = TrackTheBest(
        per_monitoring.record_name(per)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    track_the_best_cost = TrackTheBest(
        validation.record_name(cost)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    extensions += [track_the_best_cost, track_the_best_per]
    extensions.append(AdaptiveClipping(
        algorithm.total_gradient_norm.name,
        clipping, train_conf['gradient_threshold'],
        decay_rate=0.998, burnin_period=500))
    extensions += [
        SwitchOffLengthFilter(
            data.length_filter,
            after_n_batches=train_conf.get('stop_filtering')),
        FinishAfter(after_n_batches=train_conf.get('num_batches'),
                    after_n_epochs=train_conf.get('num_epochs'))
            .add_condition(["after_batch"], _gradient_norm_is_none),
    ]
    channels = [
        # Plot 1: training and validation costs
        [average_monitoring.record_name(train_cost),
         validation.record_name(cost)],
        # Plot 2: gradient norm,
        [average_monitoring.record_name(algorithm.total_gradient_norm),
         average_monitoring.record_name(clipping.threshold)],
        # Plot 3: phoneme error rate
        [per_monitoring.record_name(per)],
        # Plot 4: training and validation mean weight entropy
        [average_monitoring._record_name('weights_entropy_per_label'),
         validation._record_name('weights_entropy_per_label')],
        # Plot 5: training and validation monotonicity penalty
        [average_monitoring._record_name('weights_penalty_per_recording'),
         validation._record_name('weights_penalty_per_recording')]]
    if bokeh:
        extensions += [
            Plot(bokeh_name if bokeh_name
                 else os.path.basename(save_path),
                 channels,
                 every_n_batches=10,
                 server_url=bokeh_server)]
    extensions += [
        # Besides the regular checkpoint, separate copies are saved after
        # the epochs in which the best PER / the best cost was recorded.
        Checkpoint(save_path,
                   before_first_epoch=not fast_start, after_epoch=True,
                   every_n_batches=train_conf.get('save_every_n_batches'),
                   save_separately=["model", "log"],
                   use_cpickle=True)
        .add_condition(
            ['after_epoch'],
            OnLogRecord(track_the_best_per.notification_name),
            (root_path + "_best" + extension,))
        .add_condition(
            ['after_epoch'],
            OnLogRecord(track_the_best_cost.notification_name),
            (root_path + "_best_ll" + extension,)),
        ProgressBar()]
    extensions.append(EmbedIPython(use_main_loop_run_caller_env=True))
    if config['net']['criterion']['name'].startswith('mse'):
        extensions.append(
            LogInputsGains(
                labels, cg, recognizer.generator.readout.emitter, data))

    if train_conf.get('patience'):
        patience_conf = train_conf['patience']
        if not patience_conf.get('notification_names'):
            # setdefault will not work for empty list
            patience_conf['notification_names'] = [
                track_the_best_per.notification_name,
                track_the_best_cost.notification_name]
        extensions.append(Patience(**patience_conf))

    extensions.append(Printing(every_n_batches=1,
                               attribute_filter=PrintingFilterList()))

    return model, algorithm, data, extensions