Esempio n. 1
0
def get_sampling_model_and_input(exp_config):
    """Build a min-risk sampling model and load pretrained weights into it.

    An encoder/decoder pair is created from the sizes in ``exp_config``.
    A sampling graph is then built, conditioned on a symbolic source matrix
    and a symbolic target-prefix matrix. Finally, the parameter values stored
    at ``exp_config['saved_parameters']`` (read with ``brick_delimiter='-'``)
    are copied into the resulting ``Model``.

    Args:
        exp_config: mapping read for ``src_vocab_size``, ``enc_embed``,
            ``enc_nhids``, ``trg_vocab_size``, ``dec_embed``, ``dec_nhids``
            and ``saved_parameters``.

    Returns:
        tuple ``(sampling_model, sampling_source_input, encoder, decoder)``.
        The target-prefix input variable (named ``'target'``) is not part of
        the returned tuple.
    """
    cfg = exp_config

    # Bricks: the decoder's hidden context is the concatenation of both
    # encoder directions, hence enc_nhids * 2.
    encoder = BidirectionalEncoder(
        cfg['src_vocab_size'], cfg['enc_embed'], cfg['enc_nhids'])
    decoder = Decoder(cfg['trg_vocab_size'], cfg['dec_embed'],
                      cfg['dec_nhids'], cfg['enc_nhids'] * 2,
                      loss_function='min_risk')

    # Symbolic inputs of the sampling graph.
    logger.info('Creating theano variables')
    source_input = tensor.lmatrix('source')
    prefix_input = tensor.lmatrix('target')

    logger.info("Building sampling model")
    source_mask = tensor.ones(source_input.shape)
    representation = encoder.apply(source_input, source_mask)
    generated = decoder.generate(source_input, representation,
                                 target_prefix=prefix_input)

    # Wrapping the graph in a Model lets callers obtain a Theano function.
    logger.info("Creating Sampling Model...")
    sampling_model = Model(generated)

    # Copy pretrained values (stored in an .npz file) into the model.
    params_path = cfg['saved_parameters']
    logger.info("Loading parameters from model: {}".format(params_path))
    param_values = LoadNMT.load_parameter_values(params_path,
                                                 brick_delimiter='-')
    LoadNMT.set_model_parameters(sampling_model, param_values)

    return sampling_model, source_input, encoder, decoder
Esempio n. 2
0
def load_params_and_get_beam_search(exp_config):
    """Build a beam search over an ``InitialContextDecoder`` and load weights.

    The decoder's target transition class is chosen by name through
    ``exp_config['target_transition']``. The default is
    ``'GRUInitialStateWithInitialStateSumContext'``.

    Args:
        exp_config: mapping read for ``src_vocab_size``, ``enc_embed``,
            ``enc_nhids``, ``target_transition`` (optional),
            ``trg_vocab_size``, ``dec_embed``, ``dec_nhids``,
            ``context_dim`` and ``saved_parameters``.

    Returns:
        tuple ``(beam_search, sampling_input, sampling_context)``. The last
        two are the symbolic source (``lmatrix``) and context (``matrix``)
        inputs that the search graph is built on.
    """
    encoder = BidirectionalEncoder(exp_config['src_vocab_size'],
                                   exp_config['enc_embed'],
                                   exp_config['enc_nhids'])

    # Resolve the transition class from its configured name.
    # NOTE(review): eval() executes arbitrary text from the config; only use
    # this with trusted configuration files.
    target_transition_name = exp_config.get(
        'target_transition', 'GRUInitialStateWithInitialStateSumContext')
    transition_cls = eval(target_transition_name)

    decoder = InitialContextDecoder(
        exp_config['trg_vocab_size'], exp_config['dec_embed'],
        exp_config['dec_nhids'], exp_config['enc_nhids'] * 2,
        exp_config['context_dim'], transition_cls)

    logger.info('Creating theano variables')
    source_input = tensor.lmatrix('source')
    context_input = tensor.matrix('context_input')

    logger.info("Building sampling model")
    source_mask = tensor.ones(source_input.shape)
    representation = encoder.apply(source_input, source_mask)
    generated = decoder.generate(source_input, representation, context_input)

    # generated[1] is next_outputs; pick out the "outputs" variables
    # produced by the decoder's sequence generator.
    output_filter = VariableFilter(bricks=[decoder.sequence_generator],
                                   name="outputs")
    _, samples = output_filter(ComputationGraph(generated[1]))
    beam_search = BeamSearch(samples=samples)

    logger.info("Creating Model...")
    model = Model(generated)
    params_path = exp_config['saved_parameters']
    logger.info("Loading parameters from model: {}".format(params_path))

    # Parameter values are read from an .npz file.
    param_values = LoadNMT.load_parameter_values(params_path)
    LoadNMT.set_model_parameters(model, param_values)

    return beam_search, source_input, context_input
Esempio n. 3
0
def get_prediction_function(exp_config):
    """Compile a Theano function that computes the decoder's prediction tags.

    A bidirectional encoder and a cross-entropy ``NMTPrefixDecoder`` are
    built. The decoder's ``prediction_tags`` graph is wrapped in a Model,
    pretrained weights are loaded from ``exp_config['saved_parameters']``,
    and the compiled function is returned.

    Args:
        exp_config: mapping read for ``src_vocab_size``, ``enc_embed``,
            ``enc_nhids``, ``trg_vocab_size``, ``dec_embed``, ``dec_nhids``
            and ``saved_parameters``.

    Returns:
        The function produced by ``Model.get_theano_function()``.
    """
    logger.info('Creating theano variables')
    source, source_mask = (tensor.lmatrix('source'),
                           tensor.matrix('source_mask'))
    suffix, suffix_mask = (tensor.lmatrix('target_suffix'),
                           tensor.matrix('target_suffix_mask'))
    prefix, prefix_mask = (tensor.lmatrix('target_prefix'),
                           tensor.matrix('target_prefix_mask'))

    encoder = BidirectionalEncoder(exp_config['src_vocab_size'],
                                   exp_config['enc_embed'],
                                   exp_config['enc_nhids'])

    # Cross-entropy decoder, renamed to 'decoder' so that its parameter
    # names match those of the baseline NMT systems.
    decoder = NMTPrefixDecoder(exp_config['trg_vocab_size'],
                               exp_config['dec_embed'],
                               exp_config['dec_nhids'],
                               exp_config['enc_nhids'] * 2,
                               loss_function='cross_entropy')
    decoder.name = 'decoder'

    source_representation = encoder.apply(source, source_mask)
    prediction_tags = decoder.prediction_tags(
        source_representation, source_mask, suffix, suffix_mask,
        prefix, prefix_mask)

    logger.info('Creating computational graph')
    prediction_model = Model(prediction_tags)

    # The weights must come from a trained model; otherwise the predictions
    # are meaningless.
    param_values = LoadNMT.load_parameter_values(
        exp_config['saved_parameters'], brick_delimiter=None)
    LoadNMT.set_model_parameters(prediction_model, param_values)

    return prediction_model.get_theano_function()
Esempio n. 4
0
def load_params_and_get_beam_search(exp_config):
    """Build a beam search over a baseline ``Decoder`` and load its weights.

    Args:
        exp_config: mapping read for ``src_vocab_size``, ``enc_embed``,
            ``enc_nhids``, ``trg_vocab_size``, ``dec_embed``, ``dec_nhids``
            and ``saved_parameters``. ``saved_parameters`` is required: it is
            accessed unconditionally, so a missing key raises KeyError.
            ``brick_delimiter`` is optional and defaults to ``None``.

    Returns:
        tuple ``(beam_search, sampling_input)``. ``sampling_input`` is the
        symbolic source ``lmatrix`` the search graph is built on.
    """
    encoder = BidirectionalEncoder(exp_config['src_vocab_size'],
                                   exp_config['enc_embed'],
                                   exp_config['enc_nhids'])
    decoder = Decoder(exp_config['trg_vocab_size'],
                      exp_config['dec_embed'],
                      exp_config['dec_nhids'],
                      exp_config['enc_nhids'] * 2)

    logger.info('Creating theano variables')
    sampling_input = tensor.lmatrix('source')

    logger.info("Building sampling model")
    generated = decoder.generate(
        sampling_input,
        encoder.apply(sampling_input, tensor.ones(sampling_input.shape)))

    # generated[1] is next_outputs; the search runs over the "outputs"
    # variables of the sequence generator.
    next_outputs = generated[1]
    _, samples = VariableFilter(
        bricks=[decoder.sequence_generator],
        name="outputs")(ComputationGraph(next_outputs))
    beam_search = BeamSearch(samples=samples)

    logger.info("Creating Model...")
    model = Model(generated)
    saved_parameters = exp_config['saved_parameters']
    logger.info("Loading parameters from model: {}".format(saved_parameters))

    # Load the .npz parameter values into the model.
    delimiter = exp_config.get('brick_delimiter', None)
    param_values = LoadNMT.load_parameter_values(saved_parameters,
                                                 brick_delimiter=delimiter)
    LoadNMT.set_model_parameters(model, param_values)

    return beam_search, sampling_input
Esempio n. 5
0
def main(exp_config, source_vocab, target_vocab, dev_stream, use_bokeh=True):
    """Fine-tune a pretrained IMT model and run the Blocks main loop.

    The function:

    * builds the training graph via ``setup_model_and_stream`` and
      ``create_model``;
    * loads pretrained parameters from ``exp_config['saved_parameters']``;
    * optionally adds L2 regularization and dropout;
    * wires up checkpointing, validation and plotting extensions;
    * trains until ``finish_after`` batches.

    Args:
        exp_config: experiment configuration mapping. Required keys include
            ``saved_parameters``, ``dropout``, ``saveto``, ``config_file``
            (used when ``saveto`` does not exist yet), ``finish_after``,
            ``save_freq``, ``hook_samples``, ``reload``, ``step_clipping``
            and ``step_rule``. The optional keys ``bleu_script`` and
            ``imt_f1_validation`` enable the validators. Those also need
            ``normalized_bleu`` and ``bleu_val_freq``. ``nan_guard`` compiles
            the training function in Theano's NanGuardMode, whether or not
            dropout is enabled.
        source_vocab: source vocabulary, passed to ``setup_model_and_stream``
            and to the validators.
        target_vocab: target vocabulary, passed to ``setup_model_and_stream``
            and to the validators.
        dev_stream: development data stream used by the validators.
        use_bokeh: add a bokeh ``Plot`` extension when bokeh is available.
    """
    (train_encoder, train_decoder, theano_sampling_source_input,
     theano_sampling_context_input, generated,
     masked_stream) = setup_model_and_stream(exp_config, source_vocab,
                                             target_vocab)
    cost = create_model(train_encoder, train_decoder,
                        exp_config.get('imt_smoothing_constant', 0.005))

    logger.info("Building model")
    train_model = Model(cost)

    # Set the parameters from a trained model (.npz file). The brick
    # delimiter is configurable for legacy reasons: Blocks changed its
    # serialization API.
    logger.info("Loading parameters from model: {}".format(
        exp_config['saved_parameters']))
    param_values = LoadNMT.load_parameter_values(
        exp_config['saved_parameters'],
        brick_delimiter=exp_config.get('brick_delimiter', None))
    LoadNMT.set_model_parameters(train_model, param_values)

    logger.info('Creating computational graph')
    cg = ComputationGraph(cost)

    # GRAPH TRANSFORMATIONS FOR BETTER TRAINING
    if exp_config.get('l2_regularization', False) is True:
        l2_reg_alpha = exp_config['l2_regularization_alpha']
        logger.info(
            'Applying l2 regularization with alpha={}'.format(l2_reg_alpha))
        model_weights = VariableFilter(roles=[WEIGHT])(cg.variables)
        for W in model_weights:
            cost = cost + (l2_reg_alpha * (W**2).sum())
        # The sum above creates a new variable, so restore the name that the
        # 'decoder_cost_cost' plot channel below refers to.
        cost.name = 'decoder_cost_cost'

    cg = ComputationGraph(cost)

    # Apply dropout to the maxout outputs (the variable name is hard-coded).
    # exp_config['dropout'] is the drop probability, so you probably want it
    # <= 0.5; a value >= 1.0 disables dropout.
    if exp_config['dropout'] < 1.0:
        logger.info('Applying dropout')
        dropout_inputs = [
            x for x in cg.intermediary_variables
            if x.name == 'maxout_apply_output'
        ]
        cg = apply_dropout(cg, dropout_inputs, exp_config['dropout'])

    # Create the training directory and copy the config into it, but only
    # if the directory does not exist yet.
    if not os.path.isdir(exp_config['saveto']):
        os.makedirs(exp_config['saveto'])
        # TODO: mv the actual config file once we switch to .yaml for min-risk
        shutil.copy(exp_config['config_file'], exp_config['saveto'])

    logger.info("Initializing extensions")
    extensions = [
        FinishAfter(after_n_batches=exp_config['finish_after']),
        TrainingDataMonitoring([cost], after_batch=True),
        Printing(after_batch=True),
        CheckpointNMT(exp_config['saveto'],
                      every_n_batches=exp_config['save_freq'])
    ]

    bleu_script = exp_config.get('bleu_script', None)
    imt_f1_validation = (
        exp_config.get('imt_f1_validation', False) is not False)

    # Every validator below needs `search_model` and `samples`, so the
    # search graph must be built whenever any of them is enabled.
    if (exp_config['hook_samples'] >= 1 or bleu_script is not None
            or imt_f1_validation):
        logger.info("Building sampling model")
        search_model = Model(generated)
        _, samples = VariableFilter(
            bricks=[train_decoder.sequence_generator], name="outputs")(
                ComputationGraph(generated[1]))  # generated[1] is next_outputs

    # TODO: add a Sampler extension when hook_samples >= 1 (sampling is
    # currently broken for min-risk).
    # TODO: use multimodal meteor and BLEU validators; add a 'validator' key
    # to the IMT config.

    # Early stopping based on BLEU.
    if bleu_script is not None:
        logger.info("Building bleu validator")
        extensions.append(
            BleuValidator(theano_sampling_source_input,
                          theano_sampling_context_input,
                          samples=samples,
                          config=exp_config,
                          model=search_model,
                          data_stream=dev_stream,
                          src_vocab=source_vocab,
                          trg_vocab=target_vocab,
                          normalize=exp_config['normalized_bleu'],
                          every_n_batches=exp_config['bleu_val_freq']))

    if imt_f1_validation:
        logger.info("Building imt F1 validator")
        extensions.append(
            IMT_F1_Validator(theano_sampling_source_input,
                             theano_sampling_context_input,
                             samples=samples,
                             config=exp_config,
                             model=search_model,
                             data_stream=dev_stream,
                             src_vocab=source_vocab,
                             trg_vocab=target_vocab,
                             normalize=exp_config['normalized_bleu'],
                             every_n_batches=exp_config['bleu_val_freq']))

    # Reload model if necessary.
    if exp_config['reload']:
        extensions.append(LoadNMT(exp_config['saveto']))

    # Plot cost in bokeh if necessary.
    if use_bokeh and BOKEH_AVAILABLE:
        extensions.append(
            Plot(exp_config['model_save_directory'],
                 channels=[[
                     'decoder_cost_cost', 'validation_set_imt_f1_score',
                     'validation_set_bleu_score', 'validation_set_meteor_score'
                 ]],
                 every_n_batches=10))

    logger.info("Initializing training algorithm")

    # With dropout, the optimized cost must come from the transformed graph;
    # otherwise `cost` (including any L2 penalty) is used directly.
    training_cost = cg.outputs[0] if exp_config['dropout'] < 1.0 else cost

    algorithm_kwargs = {}
    if exp_config.get('nan_guard', False):
        # Fail fast on NaN/Inf values in the compiled training function.
        from theano.compile.nanguardmode import NanGuardMode
        algorithm_kwargs['theano_func_kwargs'] = {
            'mode': NanGuardMode(nan_is_error=True, inf_is_error=True)
        }

    # NOTE(review): eval() on the config's step_rule executes arbitrary
    # text; only use with trusted configuration files.
    algorithm = GradientDescent(cost=training_cost,
                                parameters=cg.parameters,
                                step_rule=CompositeRule([
                                    StepClipping(exp_config['step_clipping']),
                                    eval(exp_config['step_rule'])()
                                ]),
                                on_unused_sources='warn',
                                **algorithm_kwargs)

    # Enrich the logged information.
    extensions.append(Timing(every_n_batches=100))

    logger.info("Initializing main loop")
    main_loop = MainLoop(model=train_model,
                         algorithm=algorithm,
                         data_stream=masked_stream,
                         extensions=extensions)

    # Train!
    main_loop.run()
Esempio n. 6
0
def main(model, cost, config, tr_stream, dev_stream, use_bokeh=False):
    """Fine-tune ``model`` on ``cost``, starting from pretrained parameters.

    The function:

    * loads parameters from ``config['saved_parameters']`` into ``model``;
    * optionally adds L2 regularization and dropout to the cost graph;
    * when ``config['saveto']`` does not exist, creates it and dumps
      ``config`` into it as YAML;
    * sets up checkpointing, validation and plotting extensions;
    * runs a Blocks ``MainLoop`` over ``tr_stream``.

    Args:
        model: Blocks ``Model`` trained by the main loop.
        cost: Theano cost variable of ``model``.
        config: experiment configuration mapping; all settings are read
            from it.
        tr_stream: training data stream.
        dev_stream: development data stream used by the validators.
        use_bokeh: add a bokeh ``Plot`` extension when bokeh is available.

    NOTE(review): the following names are neither parameters nor locals:
    ``train_encoder``, ``train_decoder``, ``theano_sampling_source_input``,
    ``theano_sampling_context_input``, ``src_vocab`` and ``trg_vocab``.
    They must exist as module-level globals when sampling or validation is
    enabled; otherwise a NameError is raised. Confirm with the caller.
    """
    # Set the parameters from a trained model (.npz file). The brick
    # delimiter is configurable for legacy reasons: Blocks changed its
    # serialization API.
    logger.info("Loading parameters from model: {}".format(
        config['saved_parameters']))
    param_values = LoadNMT.load_parameter_values(
        config['saved_parameters'],
        brick_delimiter=config.get('brick_delimiter', None))
    LoadNMT.set_model_parameters(model, param_values)

    logger.info('Creating computational graph')
    cg = ComputationGraph(cost)

    # GRAPH TRANSFORMATIONS FOR BETTER TRAINING
    if config.get('l2_regularization', False) is True:
        l2_reg_alpha = config['l2_regularization_alpha']
        logger.info(
            'Applying l2 regularization with alpha={}'.format(l2_reg_alpha))
        model_weights = VariableFilter(roles=[WEIGHT])(cg.variables)
        for W in model_weights:
            cost = cost + (l2_reg_alpha * (W**2).sum())
        # The sum above creates a new variable, so restore the name that the
        # 'decoder_cost_cost' plot channel below refers to.
        cost.name = 'decoder_cost_cost'

    cg = ComputationGraph(cost)

    # Apply dropout to the maxout outputs (the variable name is hard-coded).
    # config['dropout'] is the drop probability, so you probably want it
    # <= 0.5; a value >= 1.0 disables dropout.
    if config['dropout'] < 1.0:
        logger.info('Applying dropout')
        dropout_inputs = [
            x for x in cg.intermediary_variables
            if x.name == 'maxout_apply_output'
        ]
        cg = apply_dropout(cg, dropout_inputs, config['dropout'])

    # Create the training directory and dump the config into it, but only
    # if the directory does not exist yet.
    if not os.path.isdir(config['saveto']):
        os.makedirs(config['saveto'])
        # TODO: this breaks when the config references a class object
        # directly instead of a name resolved by reflection.
        with codecs.open(os.path.join(config['saveto'], 'config.yaml'),
                         'w',
                         encoding='utf8') as yaml_out:
            yaml_out.write(yaml.dump(config))

    logger.info("Initializing extensions")
    extensions = [
        FinishAfter(after_n_batches=config['finish_after']),
        TrainingDataMonitoring([cost], after_batch=True),
        Printing(after_batch=True),
        CheckpointNMT(config['saveto'], every_n_batches=config['save_freq'])
    ]

    bleu_script = config.get('bleu_script', None)
    meteor_directory = config.get('meteor_directory', None)

    # Both validators below need `search_model` and `samples`, so the search
    # graph must be built whenever either of them is enabled.
    if (config['hook_samples'] >= 1 or bleu_script is not None
            or meteor_directory is not None):
        logger.info("Building sampling model")
        sampling_representation = train_encoder.apply(
            theano_sampling_source_input,
            tensor.ones(theano_sampling_source_input.shape))
        # TODO: `generated` contains several more values; inspect them.
        generated = train_decoder.generate(theano_sampling_source_input,
                                           sampling_representation,
                                           theano_sampling_context_input)
        search_model = Model(generated)
        _, samples = VariableFilter(
            bricks=[train_decoder.sequence_generator], name="outputs")(
                ComputationGraph(generated[1]))  # generated[1] is next_outputs

    # TODO: add a Sampler extension when hook_samples >= 1 (sampling is
    # currently broken for min-risk).
    # TODO: use multimodal meteor and BLEU validators.

    # Early stopping based on BLEU.
    if bleu_script is not None:
        logger.info("Building bleu validator")
        extensions.append(
            BleuValidator(theano_sampling_source_input,
                          theano_sampling_context_input,
                          samples=samples,
                          config=config,
                          model=search_model,
                          data_stream=dev_stream,
                          src_vocab=src_vocab,
                          trg_vocab=trg_vocab,
                          normalize=config['normalized_bleu'],
                          every_n_batches=config['bleu_val_freq']))

    # Early stopping based on Meteor.
    if meteor_directory is not None:
        logger.info("Building meteor validator")
        extensions.append(
            MeteorValidator(theano_sampling_source_input,
                            theano_sampling_context_input,
                            samples=samples,
                            config=config,
                            model=search_model,
                            data_stream=dev_stream,
                            src_vocab=src_vocab,
                            trg_vocab=trg_vocab,
                            normalize=config['normalized_bleu'],
                            every_n_batches=config['bleu_val_freq']))

    # Reload model if necessary.
    if config['reload']:
        extensions.append(LoadNMT(config['saveto']))

    # Plot cost in bokeh if necessary.
    if use_bokeh and BOKEH_AVAILABLE:
        extensions.append(
            Plot(config['model_save_directory'],
                 channels=[[
                     'decoder_cost_cost', 'validation_set_bleu_score',
                     'validation_set_meteor_score'
                 ]],
                 every_n_batches=10))

    logger.info("Initializing training algorithm")

    # With dropout, the optimized cost must come from the transformed graph;
    # otherwise `cost` (including any L2 penalty) is used directly.
    training_cost = cg.outputs[0] if config['dropout'] < 1.0 else cost
    # NOTE(review): eval() on the config's step_rule executes arbitrary
    # text; only use with trusted configuration files.
    algorithm = GradientDescent(cost=training_cost,
                                parameters=cg.parameters,
                                step_rule=CompositeRule([
                                    StepClipping(config['step_clipping']),
                                    eval(config['step_rule'])()
                                ]),
                                on_unused_sources='warn')

    # Enrich the logged information.
    extensions.append(Timing(every_n_batches=100))

    logger.info("Initializing main loop")
    main_loop = MainLoop(model=model,
                         algorithm=algorithm,
                         data_stream=tr_stream,
                         extensions=extensions)

    # Train!
    main_loop.run()
Esempio n. 7
0
def load_params_and_get_beam_search(exp_config,
                                    decoder=None,
                                    encoder=None,
                                    brick_delimiter=None):
    """Build a prefix-conditioned beam search and optionally load its weights.

    Args:
        exp_config: configuration mapping. The following keys are read:

            * ``src_vocab_size`` and ``enc_embed`` (only when ``encoder``
              is None);
            * ``trg_vocab_size``, ``dec_embed`` and ``dec_nhids`` (only when
              ``decoder`` is None);
            * ``enc_nhids``;
            * optional ``n_steps``, ``load_from_saved_parameters`` and
              ``confidence_saved_parameters``;
            * ``saved_parameters`` (when loading is enabled).
        decoder: an existing decoder brick to reuse. Leave it None for pure
            prediction; a cross-entropy ``NMTPrefixDecoder`` named
            ``'decoder'`` is then created.
        encoder: an existing encoder brick to reuse. A
            ``BidirectionalEncoder`` is created when None.
        brick_delimiter: forwarded to ``LoadNMT.load_parameter_values``.

    Returns:
        tuple ``(beam_search, search_model, samples, sampling_input,
        sampling_prefix)``.
    """
    if encoder is None:
        encoder = BidirectionalEncoder(exp_config['src_vocab_size'],
                                       exp_config['enc_embed'],
                                       exp_config['enc_nhids'])

    # Note: decoder should be None when we are just doing prediction, not
    # validation.
    if decoder is None:
        decoder = NMTPrefixDecoder(exp_config['trg_vocab_size'],
                                   exp_config['dec_embed'],
                                   exp_config['dec_nhids'],
                                   exp_config['enc_nhids'] * 2,
                                   loss_function='cross_entropy')
        # Rename to match baseline NMT systems so that parameters can be
        # transparently initialized from them.
        decoder.name = 'decoder'

    logger.info('Creating theano variables')
    sampling_input = tensor.lmatrix('sampling_input')
    sampling_prefix = tensor.lmatrix('sampling_target_prefix')

    logger.info("Building sampling model")
    sampling_representation = encoder.apply(sampling_input,
                                            tensor.ones(sampling_input.shape))

    # The prefix may be empty, which simulates baseline NMT.
    n_steps = exp_config.get('n_steps', None)
    generated = decoder.generate(sampling_input,
                                 sampling_representation,
                                 target_prefix=sampling_prefix,
                                 n_steps=n_steps)

    # Create the 1-step sampling graph (generated[1] is next_outputs).
    _, samples = VariableFilter(
        bricks=[decoder.sequence_generator], name="outputs")(ComputationGraph(
            generated[1]))

    beam_search = BeamSearch(samples=samples)

    logger.info("Creating Search Model...")
    search_model = Model(generated)

    # Initialize to random values first, because the confidence-model
    # parameters loaded below are optional.
    # NOTE(review): if the brick base class already defines `initialized`
    # in its constructor, these hasattr() guards are always true and the
    # branches never run — confirm.
    if not hasattr(encoder, 'initialized'):
        encoder.push_initialization_config()
        # Set the recurrent initializer *before* initialize(), as the decoder
        # branch does. Assigning it afterwards cannot affect the weights that
        # initialize() has already set.
        encoder.bidir.prototype.weights_init = Orthogonal()
        encoder.initialize()
    if not hasattr(decoder, 'initialized'):
        decoder.push_initialization_config()
        decoder.transition.weights_init = Orthogonal()
        decoder.initialize()

    # Optionally overwrite the parameters with values from .npz files. This
    # is the usual path when doing only prediction/evaluation.
    if exp_config.get('load_from_saved_parameters', False):
        logger.info("Loading parameters from model: {}".format(
            exp_config['saved_parameters']))
        param_values = LoadNMT.load_parameter_values(
            exp_config['saved_parameters'], brick_delimiter=brick_delimiter)
        LoadNMT.set_model_parameters(search_model, param_values)
        # TODO: CONFIDENCE PREDICTION SHOULD BE OPTIONAL -- RIGHT NOW IT'S
        # HARD-CODED INTO BEAM SEARCH
        if exp_config.get('confidence_saved_parameters', False):
            param_values = LoadNMT.load_parameter_values(
                exp_config['confidence_saved_parameters'],
                brick_delimiter=brick_delimiter)
            LoadNMT.set_model_parameters(search_model, param_values)

    return beam_search, search_model, samples, sampling_input, sampling_prefix
Esempio n. 8
0
def get_confidence_function(exp_config):
    """Compile a Theano function producing confidence predictions.

    The features are the decoder's merged states, concatenated with the
    maximum softmax probability of its predictions at each position. They
    are fed to ``decoder.sequence_generator.confidence_predictions``.
    Parameters are loaded first from ``exp_config['saved_parameters']``
    (the NMT model), then from ``exp_config['confidence_saved_parameters']``
    (the confidence model).

    Args:
        exp_config: mapping read for ``src_vocab_size``, ``enc_embed``,
            ``enc_nhids``, ``trg_vocab_size``, ``dec_embed``, ``dec_nhids``,
            ``saved_parameters`` and ``confidence_saved_parameters``.

    Returns:
        The function produced by ``Model.get_theano_function()``.
    """
    logger.info('Creating theano variables')
    source_sentence = tensor.lmatrix('source')
    source_sentence_mask = tensor.matrix('source_mask')
    target_sentence = tensor.lmatrix('target_suffix')
    target_sentence_mask = tensor.matrix('target_suffix_mask')
    target_prefix = tensor.lmatrix('target_prefix')
    target_prefix_mask = tensor.matrix('target_prefix_mask')

    encoder = BidirectionalEncoder(exp_config['src_vocab_size'],
                                   exp_config['enc_embed'],
                                   exp_config['enc_nhids'])

    # Cross-entropy decoder, renamed to 'decoder' so that its parameter
    # names match those of the baseline NMT systems.
    decoder = NMTPrefixDecoder(exp_config['trg_vocab_size'],
                               exp_config['dec_embed'],
                               exp_config['dec_nhids'],
                               exp_config['enc_nhids'] * 2,
                               loss_function='cross_entropy')
    decoder.name = 'decoder'

    predictions, merged_states = decoder.prediction_tags(
        encoder.apply(source_sentence, source_sentence_mask),
        source_sentence_mask, target_sentence, target_sentence_mask,
        target_prefix, target_prefix_mask)

    # TODO: add features for source length, prefix length and position in
    # the suffix. Position in the suffix only makes sense when training on
    # predictions.
    # `predictions` is 3-D (its shape is indexed as p_shape[0..2]). Softmax
    # needs a 2-D input, so flatten the first two axes once and restore the
    # original shape afterwards.
    # Assumes the last axis holds per-token scores — TODO confirm.
    p_shape = predictions.shape
    flat_predictions = predictions.reshape([p_shape[0] * p_shape[1],
                                            p_shape[2]])
    prediction_softmax = tensor.nnet.nnet.softmax(flat_predictions).reshape(
        p_shape)
    # Highest softmax probability per position, kept as a trailing
    # length-1 feature axis.
    prediction_feature = prediction_softmax.max(axis=-1)[:, :, None]
    all_features = tensor.concatenate([merged_states, prediction_feature],
                                      axis=-1)

    confidence_output = decoder.sequence_generator.confidence_predictions(
        all_features)

    logger.info('Creating computational graph')
    confidence_model = Model(confidence_output)

    # The weights must come from a trained model; otherwise the outputs are
    # meaningless. Load the NMT parameters first, then the confidence-model
    # parameters.
    param_values = LoadNMT.load_parameter_values(
        exp_config['saved_parameters'], brick_delimiter=None)
    LoadNMT.set_model_parameters(confidence_model, param_values)

    confidence_param_values = LoadNMT.load_parameter_values(
        exp_config['confidence_saved_parameters'], brick_delimiter=None)
    LoadNMT.set_model_parameters(confidence_model, confidence_param_values)

    return confidence_model.get_theano_function()