コード例 #1
0
def train_language_model(new_training_job, config, save_path, params,
                         fast_start, fuel_server, seed):
    """Build the language-model training graph and run the main loop.

    Parameters
    ----------
    new_training_job : bool
        When false, the load extension is armed ``before_training`` so
        that a previously saved state is restored from `save_path`.
    config : dict-like
        Experiment configuration. Keys read here: ``dict_path``,
        ``embedding_path``, ``cache_size``, ``grad_clip_threshold``,
        ``learning_rate``, ``momentum``, ``monitor_parameters``,
        ``batch_size``, ``batch_size_valid``, ``max_length``,
        ``mon_freq_train``, ``mon_freq_valid``, ``fast_checkpoint``,
        ``save_freq_batches``, ``checkpoint_every_n_batches`` and
        ``n_batches``.
    save_path : str
        Directory in which ``main_loop.tar``, ``training_state.tar``,
        ``stream.pkl`` and ``best_model.tar`` are placed.
    params : str or None
        Path of a parameter archive to initialise the model from; skipped
        when falsy.
    fast_start : bool
        When true, the initial validation pass, the before/after-training
        checkpoints and the initial retrieval statistics are disabled.
    fuel_server : bool
        When true, training batches are read through a
        ``ServerDataStream`` and the streams get a random seed.
    seed : int or None
        Default seed installed into the Fuel and Blocks configuration
        when truthy.

    Returns
    -------
    None
        The function runs ``MainLoop.run()`` for its side effects.
    """
    c = config
    # NOTE(review): a seed of 0 is treated as "no seed" by this truthiness
    # test -- confirm against the caller's default before changing it.
    if seed:
        fuel.config.default_seed = seed
        blocks.config.config.default_seed = seed

    data, lm, retrieval = initialize_data_and_model(config)

    # The full main loop can be saved...
    main_loop_path = os.path.join(save_path, 'main_loop.tar')
    # ...or only the state (log + params), which avoids pickling embeddings.
    state_path = os.path.join(save_path, 'training_state.tar')
    stream_path = os.path.join(save_path, 'stream.pkl')
    best_tar_path = os.path.join(save_path, "best_model.tar")

    words = tensor.ltensor3('words')
    words_mask = tensor.matrix('words_mask')
    if theano.config.compute_test_value != 'off':
        # Attach one tiny real batch so Theano can propagate test values.
        test_value_data = next(
            data.get_stream('train', batch_size=4,
                            max_length=5).get_epoch_iterator())
        words.tag.test_value = test_value_data[0]
        words_mask.tag.test_value = test_value_data[1]

    costs, updates = lm.apply(words, words_mask)
    cost = rename(costs.mean(), 'mean_cost')

    cg = Model(cost)
    if params:
        logger.debug("Load parameters from {}".format(params))
        # The parameter file is a tar archive, so it must be read as bytes
        # (text mode breaks on Python 3 and is equivalent on Python 2).
        with open(params, 'rb') as src:
            cg.set_parameter_values(load_parameters(src))

    length = rename(words.shape[1], 'length')
    perplexity, = VariableFilter(name='perplexity')(cg)
    perplexities = VariableFilter(name_regex='perplexity.*')(cg)
    monitored_vars = [length, cost] + perplexities
    if c['dict_path']:
        num_definitions, = VariableFilter(name='num_definitions')(cg)
        monitored_vars.extend([num_definitions])

    parameters = cg.get_parameter_dict()
    # Materialise real lists: on Python 3 ``dict.values()`` is a live view,
    # while the code below filters these and hands them to Blocks as lists.
    trained_parameters = list(parameters.values())
    saved_parameters = list(parameters.values())
    if c['embedding_path']:
        logger.debug("Exclude word embeddings from the trained parameters")
        trained_parameters = [
            p for p in trained_parameters
            if not p == lm.get_def_embeddings_params()
        ]
        saved_parameters = [
            p for p in saved_parameters
            if not p == lm.get_def_embeddings_params()
        ]

    if c['cache_size'] != 0:
        logger.debug("Enable fake recursivity for looking up embeddings")
        # Cache parameters are excluded from training only; they are still
        # part of `saved_parameters`.
        trained_parameters = [
            p for p in trained_parameters if not p == lm.get_cache_params()
        ]

    parameter_report = [
        " ".join((key,
                  str(parameters[key].get_value().shape),
                  'trained' if parameters[key] in trained_parameters
                  else 'frozen'))
        for key in sorted(parameters.keys())
    ]
    logger.info("Cost parameters" + "\n" +
                pprint.pformat(parameter_report, width=120))

    rules = []
    if c['grad_clip_threshold']:
        rules.append(StepClipping(c['grad_clip_threshold']))
    rules.append(Adam(learning_rate=c['learning_rate'], beta1=c['momentum']))
    algorithm = GradientDescent(cost=cost,
                                parameters=trained_parameters,
                                step_rule=CompositeRule(rules))

    if c['cache_size'] != 0:
        # The updates returned by `lm.apply` are only applied when the
        # cache is enabled.
        algorithm.add_updates(updates)

    train_monitored_vars = list(monitored_vars)
    if c['grad_clip_threshold']:
        train_monitored_vars.append(algorithm.total_gradient_norm)

    word_emb_RMS, = VariableFilter(name='word_emb_RMS')(cg)
    main_rnn_in_RMS, = VariableFilter(name='main_rnn_in_RMS')(cg)
    train_monitored_vars.extend([word_emb_RMS, main_rnn_in_RMS])

    if c['monitor_parameters']:
        train_monitored_vars.extend(parameter_stats(parameters, algorithm))

    # We use a completely random seed on purpose. With Fuel server
    # it's currently not possible to restore the state of the training
    # stream. That's why it's probably better to just have it stateless.
    stream_seed = numpy.random.randint(0, 10000000) if fuel_server else None
    training_stream = data.get_stream('train',
                                      batch_size=c['batch_size'],
                                      max_length=c['max_length'],
                                      seed=stream_seed)
    valid_stream = data.get_stream('valid',
                                   batch_size=c['batch_size_valid'],
                                   max_length=c['max_length'],
                                   seed=stream_seed)
    # Keep the in-process stream: StartFuelServer below receives it, while
    # the main loop reads from the server-backed proxy.
    original_training_stream = training_stream
    if fuel_server:
        # the port will be configured by the StartFuelServer extension
        training_stream = ServerDataStream(
            sources=training_stream.sources,
            produces_examples=training_stream.produces_examples)

    validation = DataStreamMonitoring(
        monitored_vars, valid_stream, prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            on_resumption=True,
            every_n_batches=c['mon_freq_valid'])
    track_the_best = TrackTheBest(
        validation.record_name(perplexity), choose_best=min).set_conditions(
            on_resumption=True,
            after_epoch=True,
            every_n_batches=c['mon_freq_valid'])

    # Both checkpointing modes build the same three extensions; they differ
    # only in the loader class, the archive path and the save arguments.
    if c['fast_checkpoint']:
        # Do not save the entire main loop, to avoid pickling everything.
        load_class = LoadNoUnpickling
        checkpoint_path = state_path
        cp_args = {
            'save_main_loop': False,
            'save_separately': ['log', 'iteration_state'],
            'parameters': saved_parameters
        }
    else:
        load_class = Load
        checkpoint_path = main_loop_path
        cp_args = {
            'save_separately': ['iteration_state'],
            'parameters': saved_parameters
        }

    load = load_class(checkpoint_path,
                      load_iteration_state=True,
                      load_log=True).set_conditions(
                          before_training=not new_training_job)
    checkpoint = Checkpoint(checkpoint_path,
                            before_training=not fast_start,
                            every_n_batches=c['save_freq_batches'],
                            after_training=not fast_start,
                            **cp_args)
    intermediate_cp = None
    if c['checkpoint_every_n_batches']:
        intermediate_cp = IntermediateCheckpoint(
            checkpoint_path,
            every_n_batches=c['checkpoint_every_n_batches'],
            after_training=False,
            **cp_args)

    # Extra checkpoint condition tied to the "best so far" log record; the
    # best-model path is passed as the callback argument.
    checkpoint = checkpoint.add_condition(
        ['after_batch', 'after_epoch'],
        OnLogRecord(track_the_best.notification_name), (best_tar_path, ))

    extensions = [
        load,
        StartFuelServer(original_training_stream,
                        stream_path,
                        before_training=fuel_server),
        Timing(every_n_batches=c['mon_freq_train'])
    ]

    if retrieval:
        extensions.append(
            RetrievalPrintStats(retrieval=retrieval,
                                every_n_batches=c['mon_freq_train'],
                                before_training=not fast_start))

    extensions.extend([
        TrainingDataMonitoring(train_monitored_vars,
                               prefix="train",
                               every_n_batches=c['mon_freq_train']),
        validation, track_the_best, checkpoint
    ])
    if intermediate_cp is not None:
        extensions.append(intermediate_cp)
    extensions.extend([
        DumpTensorflowSummaries(save_path,
                                every_n_batches=c['mon_freq_train'],
                                after_training=True),
        Printing(on_resumption=True, every_n_batches=c['mon_freq_train']),
        FinishIfNoImprovementAfter(track_the_best.notification_name,
                                   iterations=50 * c['mon_freq_valid'],
                                   every_n_batches=c['mon_freq_valid']),
        FinishAfter(after_n_batches=c['n_batches'])
    ])

    logger.info("monitored variables during training:" + "\n" +
                pprint.pformat(train_monitored_vars, width=120))
    logger.info("monitored variables during valid:" + "\n" +
                pprint.pformat(monitored_vars, width=120))

    main_loop = MainLoop(algorithm,
                         training_stream,
                         model=Model(cost),
                         extensions=extensions)

    main_loop.run()
コード例 #2
0
def initialaze_algorithm(config, save_path, bokeh_name, params, bokeh_server,
                         bokeh, use_load_ext, load_log, fast_start,
                         recognizer, data, model, cg, regularized_cg,
                         cost, train_cost, parameters,
                         max_norm_rules, observables,
                         batch_size, batch_cost, weights_entropy,
                         labels_mask, labels,  gradients=None):
    """Create the training algorithm and the list of main-loop extensions.

    Parameters
    ----------
    config : dict-like
        Uses the ``'training'`` and ``'monitoring'`` sections.
    save_path : str
        Checkpoint path; ``_best`` / ``_best_ll`` variants are derived
        from it by inserting the suffix before the file extension.
    bokeh_name, bokeh_server, bokeh
        Plot name, server URL and on/off switch for the ``Plot`` extension.
    params : str or None
        Path handed to ``Load`` / ``LoadLog`` when the respective flag
        (`use_load_ext`, `load_log`) is set.
    fast_start : bool
        When true, monitoring and checkpointing before the first epoch are
        disabled.
    recognizer, data, model
        Project objects; `model` is returned untouched.
    regularized_cg
        Graph whose outputs are added to the averaged observables.
    cost, train_cost, batch_size, batch_cost
        Theano variables used for monitoring and as the training cost.
    parameters : dict
        Mapping from parameter name to shared variable.
    max_norm_rules : list
        Extra step rules inserted after the core rules.
    observables : list
        Variables monitored after every batch. The list is copied, so the
        caller's list is left unmodified.
    weights_entropy : list
        Variables concatenated onto the validation observables.
    labels_mask
        Indexed by child id when attaching aggregation schemes.
    cg, labels
        Accepted for interface compatibility; not used in this function.
    gradients : optional
        Forwarded to ``GradientDescent``.

    Returns
    -------
    tuple
        ``(model, algorithm, data, extensions)``.
    """
    # Copy so that the in-place ``+=`` below does not leak into the list
    # owned by the caller.
    primary_observables = list(observables)
    secondary_observables = []
    validation_observables = []
    root_path, extension = os.path.splitext(save_path)
    train_conf = config['training']

    # Define the training algorithm.
    clipping = StepClipping(train_conf['gradient_threshold'])
    clipping.threshold.name = "gradient_norm_threshold"
    rule_names = train_conf.get('rules', ['momentum'])
    core_rules = []
    if 'momentum' in rule_names:
        logger.info("Using scaling and momentum for training")
        core_rules.append(
            Momentum(train_conf['scale'], train_conf['momentum']))
    if 'adadelta' in rule_names:
        logger.info("Using AdaDelta for training")
        core_rules.append(
            AdaDelta(train_conf['decay_rate'], train_conf['epsilon']))
    if 'adam' in rule_names:
        # Adam is not combined with the other rules.
        assert len(rule_names) == 1
        logger.info("Using Adam for training")
        core_rules.append(
            Adam(learning_rate=train_conf.get('scale', 0.002),
                 beta1=train_conf.get('beta1', 0.1),
                 beta2=train_conf.get('beta2', 0.001),
                 epsilon=train_conf.get('epsilon', 1e-8),
                 decay_factor=train_conf.get('decay_rate', (1 - 1e-8))))
    burn_in = []
    if train_conf.get('burn_in_steps', 0):
        burn_in.append(
            BurnIn(num_steps=train_conf['burn_in_steps']))
    algorithm = GradientDescent(
        cost=train_cost,
        parameters=parameters.values(),
        gradients=gradients,
        step_rule=CompositeRule(
            [clipping] + core_rules + max_norm_rules +
            # Parameters are not changed at all
            # when nans are encountered.
            [RemoveNotFinite(0.0)] + burn_in),
        on_unused_sources='warn')
        #theano_func_kwargs={'mode':NanGuardMode(nan_is_error=True)})

    logger.debug("Scan Ops in the gradients")
    gradient_cg = ComputationGraph(algorithm.gradients.values())
    # NOTE(review): `gradient_cg` is already a ComputationGraph; wrapping it
    # a second time looks redundant -- confirm before simplifying.
    for op in ComputationGraph(gradient_cg).scans:
        logger.debug(op)

    # More variables for debugging: some of them can be added only
    # after the `algorithm` object is created.
    secondary_observables += list(regularized_cg.outputs)
    if 'train_cost' not in [v.name for v in secondary_observables]:
        secondary_observables += [train_cost]
    secondary_observables += [
        algorithm.total_step_norm, algorithm.total_gradient_norm,
        clipping.threshold]
    for name, param in parameters.items():
        # `numpy.prod` replaces the deprecated alias `numpy.product`.
        num_elements = numpy.prod(param.get_value().shape)
        # Each norm is divided by sqrt(number of elements).
        norm = param.norm(2) / num_elements ** 0.5
        grad_norm = algorithm.gradients[param].norm(2) / num_elements ** 0.5
        step_norm = algorithm.steps[param].norm(2) / num_elements ** 0.5
        stats = tensor.stack(norm, grad_norm, step_norm, step_norm / grad_norm)
        stats.name = name + '_stats'
        secondary_observables.append(stats)

    primary_observables += [
        train_cost,
        algorithm.total_gradient_norm,
        algorithm.total_step_norm, clipping.threshold]

    validation_observables += [
        rename(aggregation.mean(batch_cost, batch_size), cost.name),
        rename(aggregation.sum_(batch_size), 'num_utterances')] + weights_entropy

    def attach_aggregation_schemes(variables):
        # Aggregation specification has to be factored out as a separate
        # function as it has to be applied at the very last stage
        # separately to training and validation observables.
        result = []
        for var in variables:
            if var.name.startswith('weights_entropy'):
                chld_id = recognizer.child_id_from_postfix(var.name)
                result.append(rename(aggregation.mean(var, labels_mask[chld_id].sum()),
                                     'weights_entropy_per_label'+
                                     recognizer.children[chld_id].names_postfix))
            elif var.name.endswith('_nll'):
                chld_id = recognizer.child_id_from_postfix(var.name)
                result.append(rename(aggregation.mean(var.sum(),
                                                      labels_mask[chld_id].sum()),
                                     var.name+'_per_label'))
            else:
                result.append(var)
        return result

    mon_conf = config['monitoring']
    # Build main loop.
    logger.info("Initialize extensions")
    extensions = []
    if use_load_ext and params:
        extensions.append(Load(params, load_iteration_state=True, load_log=True))
    if load_log and params:
        extensions.append(LoadLog(params))
    extensions += [
        Timing(after_batch=True),
        CGStatistics(),
        #CodeVersion(['lvsr']),
    ]
    extensions.append(TrainingDataMonitoring(
        primary_observables, after_batch=True))
    average_monitoring = TrainingDataMonitoring(
        attach_aggregation_schemes(secondary_observables),
        prefix="average", every_n_batches=10)
    extensions.append(average_monitoring)
    # NOTE(review): `data_params_valid` is neither a parameter nor a local
    # of this function; it is presumably a module-level name -- confirm.
    validation = DataStreamMonitoring(
        attach_aggregation_schemes(validation_observables),
        data.get_stream("valid", shuffle=False, **data_params_valid), prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=mon_conf['validate_every_epochs'],
            every_n_batches=mon_conf['validate_every_batches'],
            after_training=False)
    extensions.append(validation)

    # Search-based error rates, computed on the first language's
    # validation stream.
    uas = DependencyErrorRate(recognizer.children[0], data,
                              **config['monitoring']['search'])
    las = AuxiliaryErrorRates(uas, name='LAS')
    lab = AuxiliaryErrorRates(uas, name='LAB')
    per_monitoring = DataStreamMonitoring(
        [uas, las, lab], data.get_one_stream("valid", data.langs[0], batches=False, shuffle=False, **data_params_valid)[0],
        prefix="valid").set_conditions(
                before_first_epoch=not fast_start,
                every_n_epochs=mon_conf['search_every_epochs'],
                every_n_batches=mon_conf['search_every_batches'],
                after_training=False)
    extensions.append(per_monitoring)
    track_the_best_uas = TrackTheBest(
        per_monitoring.record_name(uas)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    track_the_best_las = TrackTheBest(
        per_monitoring.record_name(las)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    track_the_best_lab = TrackTheBest(
        per_monitoring.record_name(lab)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    extensions += [track_the_best_uas,
                   track_the_best_las,
                   track_the_best_lab,
                   ]
    # UAS is the primary metric for plotting and best-model checkpoints.
    per = uas
    track_the_best_per = track_the_best_uas
    # These are merged into `notification_names` below, so they must be
    # the notification *names* (strings), like the other entries of that
    # list, not the TrackTheBest extension objects themselves.
    additional_patience_notifiers = [
        track_the_best_lab.notification_name,
        track_the_best_las.notification_name]
    track_the_best_cost = TrackTheBest(
        validation.record_name(cost)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    extensions += [track_the_best_cost]
    extensions.append(AdaptiveClipping(
        algorithm.total_gradient_norm.name,
        clipping, train_conf['gradient_threshold'],
        decay_rate=0.998, burnin_period=500,
        num_stds=train_conf.get('clip_stds', 1.0)))
    extensions += [
        SwitchOffLengthFilter(
            data.length_filter,
            after_n_batches=train_conf.get('stop_filtering')),
        FinishAfter(after_n_batches=train_conf['num_batches'],
                    after_n_epochs=train_conf['num_epochs']),
            # .add_condition(["after_batch"], _gradient_norm_is_none),
    ]
    main_postfix = recognizer.children[0].names_postfix
    channels = [
        # Plot 1: training and validation costs
        [average_monitoring.record_name(train_cost),
         validation.record_name(cost)],
        # Plot 2: gradient norm,
        [average_monitoring.record_name(algorithm.total_gradient_norm),
         average_monitoring.record_name(clipping.threshold)],
        # Plot 3: primary search error rate (`per` is the UAS metric here)
        [per_monitoring.record_name(per)],
        # Plot 4: training and validation mean weight entropy
        [average_monitoring._record_name('weights_entropy_per_label'+main_postfix),
         validation._record_name('weights_entropy_per_label'+main_postfix)],
        # Plot 5: training and validation monotonicity penalty
        [average_monitoring._record_name('weights_penalty_per_recording'+main_postfix),
         validation._record_name('weights_penalty_per_recording'+main_postfix)]]
    if bokeh:
        extensions += [
            Plot(bokeh_name if bokeh_name
                 else os.path.basename(save_path),
                 channels,
                 every_n_batches=10,
                 server_url=bokeh_server),]
    extensions += [
        # Regular checkpoint, plus two extra conditions keyed on the
        # "best UAS" and "best validation cost" log records, each with its
        # own derived path as the callback argument.
        Checkpoint(save_path,
                   before_first_epoch=not fast_start, after_epoch=True,
                   every_n_batches=train_conf.get('save_every_n_batches'),
                   save_separately=["model", "log"],
                   use_cpickle=True)
        .add_condition(
            ['after_epoch'],
            OnLogRecord(track_the_best_per.notification_name),
            (root_path + "_best" + extension,))
        .add_condition(
            ['after_epoch'],
            OnLogRecord(track_the_best_cost.notification_name),
            (root_path + "_best_ll" + extension,)),
        ProgressBar()]
    extensions.append(EmbedIPython(use_main_loop_run_caller_env=True))

    if train_conf.get('patience'):
        patience_conf = train_conf['patience']
        if not patience_conf.get('notification_names'):
            # setdefault will not work for empty list
            patience_conf['notification_names'] = [
                track_the_best_per.notification_name,
                track_the_best_cost.notification_name] + additional_patience_notifiers
        extensions.append(Patience(**patience_conf))

    if train_conf.get('min_performance_stops'):
        extensions.append(EarlyTermination(
            param_name=track_the_best_per.best_name,
            min_performance_by_epoch=train_conf['min_performance_stops']))

    extensions.append(Printing(every_n_batches=1,
                               attribute_filter=PrintingFilterList()))

    return model, algorithm, data, extensions
コード例 #3
0
ファイル: main.py プロジェクト: dmitriy-serdyuk/twinnet-asr
def initialize_all(config, save_path, bokeh_name, params, bokeh_server, bokeh,
                   test_tag, use_load_ext, load_log, fast_start):
    root_path, extension = os.path.splitext(save_path)

    data = Data(**config['data'])
    train_conf = config['training']
    recognizer = create_model(config, data, test_tag)

    # Separate attention_params to be handled differently
    # when regularization is applied
    attention = recognizer.generator.transition.attention
    attention_params = Selector(attention).get_parameters().values()

    logger.info(
        "Initialization schemes for all bricks.\n"
        "Works well only in my branch with __repr__ added to all them,\n"
        "there is an issue #463 in Blocks to do that properly.")

    def show_init_scheme(cur):
        result = dict()
        for attr in dir(cur):
            if attr.endswith('_init'):
                result[attr] = getattr(cur, attr)
        for child in cur.children:
            result[child.name] = show_init_scheme(child)
        return result

    logger.info(pprint.pformat(show_init_scheme(recognizer)))

    prediction, prediction_mask = add_exploration(recognizer, data, train_conf)

    #
    # Observables:
    #
    primary_observables = []  # monitored each batch
    secondary_observables = []  # monitored every 10 batches
    validation_observables = []  # monitored on the validation set

    cg = recognizer.get_cost_graph(batch=True,
                                   prediction=prediction,
                                   prediction_mask=prediction_mask)
    labels, = VariableFilter(applications=[recognizer.cost], name='labels')(cg)
    labels_mask, = VariableFilter(applications=[recognizer.cost],
                                  name='labels_mask')(cg)

    gain_matrix = VariableFilter(
        theano_name=RewardRegressionEmitter.GAIN_MATRIX)(cg)
    if len(gain_matrix):
        gain_matrix, = gain_matrix
        primary_observables.append(rename(gain_matrix.min(), 'min_gain'))
        primary_observables.append(rename(gain_matrix.max(), 'max_gain'))

    batch_cost = cg.outputs[0].sum()
    batch_size = rename(recognizer.labels.shape[1], "batch_size")
    # Assumes constant batch size. `aggregation.mean` is not used because
    # of Blocks #514.
    cost = batch_cost / batch_size
    cost.name = "sequence_total_cost"
    logger.info("Cost graph is built")

    # Fetch variables useful for debugging.
    # It is important not to use any aggregation schemes here,
    # as it's currently impossible to spread the effect of
    # regularization on their variables, see Blocks #514.
    cost_cg = ComputationGraph(cost)
    r = recognizer
    energies, = VariableFilter(applications=[r.generator.readout.readout],
                               name="output_0")(cost_cg)
    bottom_output = VariableFilter(
        # We need name_regex instead of name because LookupTable calls itsoutput output_0
        applications=[r.bottom.apply],
        name_regex="output")(cost_cg)[-1]
    attended, = VariableFilter(applications=[r.generator.transition.apply],
                               name="attended")(cost_cg)
    attended_mask, = VariableFilter(applications=[
        r.generator.transition.apply
    ],
                                    name="attended_mask")(cost_cg)
    weights, = VariableFilter(applications=[r.generator.evaluate],
                              name="weights")(cost_cg)

    from blocks.roles import AUXILIARY
    l2_cost, = VariableFilter(roles=[AUXILIARY],
                              theano_name='l2_cost_aux')(cost_cg)
    cost_forward, = VariableFilter(roles=[AUXILIARY],
                                   theano_name='costs_forward_aux')(cost_cg)

    max_recording_length = rename(bottom_output.shape[0],
                                  "max_recording_length")
    # To exclude subsampling related bugs
    max_attended_mask_length = rename(attended_mask.shape[0],
                                      "max_attended_mask_length")
    max_attended_length = rename(attended.shape[0], "max_attended_length")
    max_num_phonemes = rename(labels.shape[0], "max_num_phonemes")
    min_energy = rename(energies.min(), "min_energy")
    max_energy = rename(energies.max(), "max_energy")
    mean_attended = rename(abs(attended).mean(), "mean_attended")
    mean_bottom_output = rename(
        abs(bottom_output).mean(), "mean_bottom_output")
    weights_penalty = rename(monotonicity_penalty(weights, labels_mask),
                             "weights_penalty")
    weights_entropy = rename(entropy(weights, labels_mask), "weights_entropy")
    mask_density = rename(labels_mask.mean(), "mask_density")
    cg = ComputationGraph([
        cost, weights_penalty, weights_entropy, min_energy, max_energy,
        mean_attended, mean_bottom_output, batch_size, max_num_phonemes,
        mask_density
    ])
    # Regularization. It is applied explicitly to all variables
    # of interest, it could not be applied to the cost only as it
    # would not have effect on auxiliary variables, see Blocks #514.
    reg_config = config.get('regularization', dict())
    regularized_cg = cg
    if reg_config.get('dropout'):
        logger.info('apply dropout')
        regularized_cg = apply_dropout(cg, [bottom_output], 0.5)
    if reg_config.get('noise'):
        logger.info('apply noise')
        noise_subjects = [
            p for p in cg.parameters if p not in attention_params
        ]
        regularized_cg = apply_noise(cg, noise_subjects, reg_config['noise'])

    train_cost = regularized_cg.outputs[0]
    if reg_config.get("penalty_coof", .0) > 0:
        # big warning!!!
        # here we assume that:
        # regularized_weights_penalty = regularized_cg.outputs[1]
        train_cost = (train_cost + reg_config.get("penalty_coof", .0) *
                      regularized_cg.outputs[1] / batch_size)
    if reg_config.get("decay", .0) > 0:
        train_cost = (
            train_cost + reg_config.get("decay", .0) *
            l2_norm(VariableFilter(roles=[WEIGHT])(cg.parameters))**2)

    train_cost = rename(train_cost, 'train_cost')

    gradients = None
    if reg_config.get('adaptive_noise'):
        logger.info('apply adaptive noise')
        if ((reg_config.get("penalty_coof", .0) > 0)
                or (reg_config.get("decay", .0) > 0)):
            logger.error('using  adaptive noise with alignment weight panalty '
                         'or weight decay is probably stupid')
        train_cost, regularized_cg, gradients, noise_brick = apply_adaptive_noise(
            cg,
            cg.outputs[0],
            variables=cg.parameters,
            num_examples=data.get_dataset('train').num_examples,
            parameters=Model(
                regularized_cg.outputs[0]).get_parameter_dict().values(),
            **reg_config.get('adaptive_noise'))
        train_cost.name = 'train_cost'
        adapt_noise_cg = ComputationGraph(train_cost)
        model_prior_mean = rename(
            VariableFilter(applications=[noise_brick.apply],
                           name='model_prior_mean')(adapt_noise_cg)[0],
            'model_prior_mean')
        model_cost = rename(
            VariableFilter(applications=[noise_brick.apply],
                           name='model_cost')(adapt_noise_cg)[0], 'model_cost')
        model_prior_variance = rename(
            VariableFilter(applications=[noise_brick.apply],
                           name='model_prior_variance')(adapt_noise_cg)[0],
            'model_prior_variance')
        regularized_cg = ComputationGraph(
            [train_cost, model_cost] + regularized_cg.outputs +
            [model_prior_mean, model_prior_variance])
        primary_observables += [
            regularized_cg.outputs[1],  # model cost
            regularized_cg.outputs[2],  # task cost
            regularized_cg.outputs[-2],  # model prior mean
            regularized_cg.outputs[-1]
        ]  # model prior variance

    model = Model(train_cost)
    if params:
        logger.info("Load parameters from " + params)
        # please note: we cannot use recognizer.load_params
        # as it builds a new computation graph that dies not have
        # shapred variables added by adaptive weight noise
        with open(params, 'r') as src:
            param_values = load_parameters(src)
        model.set_parameter_values(param_values)

    parameters = model.get_parameter_dict()
    logger.info("Parameters:\n" +
                pprint.pformat([(key, parameters[key].get_value().shape)
                                for key in sorted(parameters.keys())],
                               width=120))

    # Define the training algorithm.
    clipping = StepClipping(train_conf['gradient_threshold'])
    clipping.threshold.name = "gradient_norm_threshold"
    rule_names = train_conf.get('rules', ['momentum'])
    core_rules = []
    if 'momentum' in rule_names:
        logger.info("Using scaling and momentum for training")
        core_rules.append(Momentum(train_conf['scale'],
                                   train_conf['momentum']))
    if 'adadelta' in rule_names:
        logger.info("Using AdaDelta for training")
        core_rules.append(
            AdaDelta(train_conf['decay_rate'], train_conf['epsilon']))
    max_norm_rules = []
    if reg_config.get('max_norm', False) > 0:
        logger.info("Apply MaxNorm")
        maxnorm_subjects = VariableFilter(roles=[WEIGHT])(cg.parameters)
        if reg_config.get('max_norm_exclude_lookup', False):
            maxnorm_subjects = [
                v for v in maxnorm_subjects
                if not isinstance(get_brick(v), LookupTable)
            ]
        logger.info("Parameters covered by MaxNorm:\n" + pprint.pformat(
            [name for name, p in parameters.items() if p in maxnorm_subjects]))
        logger.info("Parameters NOT covered by MaxNorm:\n" + pprint.pformat([
            name for name, p in parameters.items() if not p in maxnorm_subjects
        ]))
        max_norm_rules = [
            Restrict(VariableClipping(reg_config['max_norm'], axis=0),
                     maxnorm_subjects)
        ]
    burn_in = []
    if train_conf.get('burn_in_steps', 0):
        burn_in.append(BurnIn(num_steps=train_conf['burn_in_steps']))
    algorithm = GradientDescent(
        cost=train_cost,
        parameters=parameters.values(),
        gradients=gradients,
        step_rule=CompositeRule(
            [clipping] + core_rules + max_norm_rules +
            # Parameters are not changed at all
            # when nans are encountered.
            [RemoveNotFinite(0.0)] + burn_in),
        on_unused_sources='warn')

    logger.debug("Scan Ops in the gradients")
    gradient_cg = ComputationGraph(algorithm.gradients.values())
    for op in ComputationGraph(gradient_cg).scans:
        logger.debug(op)

    # More variables for debugging: some of them can be added only
    # after the `algorithm` object is created.
    secondary_observables += list(regularized_cg.outputs)
    if not 'train_cost' in [v.name for v in secondary_observables]:
        secondary_observables += [train_cost]
    secondary_observables += [
        algorithm.total_step_norm, algorithm.total_gradient_norm,
        clipping.threshold
    ]
    for name, param in parameters.items():
        num_elements = numpy.product(param.get_value().shape)
        norm = param.norm(2) / num_elements**0.5
        grad_norm = algorithm.gradients[param].norm(2) / num_elements**0.5
        step_norm = algorithm.steps[param].norm(2) / num_elements**0.5
        stats = tensor.stack(norm, grad_norm, step_norm, step_norm / grad_norm)
        stats.name = name + '_stats'
        secondary_observables.append(stats)

    primary_observables += [
        train_cost, algorithm.total_gradient_norm, algorithm.total_step_norm,
        clipping.threshold, max_recording_length, max_attended_length,
        max_attended_mask_length
    ]

    validation_observables += [
        rename(aggregation.mean(batch_cost, batch_size), cost.name),
        rename(aggregation.sum_(batch_size), 'num_utterances'),
        weights_entropy, weights_penalty
    ]

    def attach_aggregation_schemes(variables):
        """Return a new list where selected variables carry aggregation schemes.

        Aggregation has to be specified at the very last stage, and
        separately for training and validation observables, which is why
        it lives in its own function.

        A variable named 'weights_penalty' is replaced by
        `aggregation.mean(var, batch_size)` renamed to
        'weights_penalty_per_recording'; one named 'weights_entropy' is
        replaced by `aggregation.mean(var, labels_mask.sum())` renamed to
        'weights_entropy_per_label'. Every other variable is passed
        through untouched, in the original order.
        """
        def with_scheme(variable):
            # Dispatch on the variable name; unknown names fall through.
            if variable.name == 'weights_penalty':
                averaged = aggregation.mean(variable, batch_size)
                return rename(averaged, 'weights_penalty_per_recording')
            if variable.name == 'weights_entropy':
                averaged = aggregation.mean(variable, labels_mask.sum())
                return rename(averaged, 'weights_entropy_per_label')
            return variable

        return [with_scheme(variable) for variable in variables]

    mon_conf = config['monitoring']

    # Build main loop.
    logger.info("Initialize extensions")
    extensions = []
    if use_load_ext and params:
        extensions.append(
            Load(params, load_iteration_state=True, load_log=True))
    if load_log and params:
        extensions.append(LoadLog(params))
    extensions += [
        Timing(after_batch=True),
        CGStatistics(),
        #CodeVersion(['lvsr']),
    ]
    extensions.append(
        TrainingDataMonitoring(primary_observables + [l2_cost, cost_forward],
                               after_batch=True))
    average_monitoring = TrainingDataMonitoring(
        attach_aggregation_schemes(secondary_observables),
        prefix="average",
        every_n_batches=10)
    extensions.append(average_monitoring)
    validation = DataStreamMonitoring(
        attach_aggregation_schemes(validation_observables +
                                   [l2_cost, cost_forward]),
        data.get_stream("valid", shuffle=False),
        prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=mon_conf['validate_every_epochs'],
            every_n_batches=mon_conf['validate_every_batches'],
            after_training=False)
    extensions.append(validation)
    per = PhonemeErrorRate(recognizer, data, **config['monitoring']['search'])
    per_monitoring = DataStreamMonitoring(
        [per],
        data.get_stream("valid", batches=False, shuffle=False),
        prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=mon_conf['search_every_epochs'],
            every_n_batches=mon_conf['search_every_batches'],
            after_training=False)
    extensions.append(per_monitoring)
    track_the_best_per = TrackTheBest(
        per_monitoring.record_name(per)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    track_the_best_cost = TrackTheBest(
        validation.record_name(cost)).set_conditions(before_first_epoch=True,
                                                     after_epoch=True)
    extensions += [track_the_best_cost, track_the_best_per]
    extensions.append(
        AdaptiveClipping(algorithm.total_gradient_norm.name,
                         clipping,
                         train_conf['gradient_threshold'],
                         decay_rate=0.998,
                         burnin_period=500))
    extensions += [
        SwitchOffLengthFilter(
            data.length_filter,
            after_n_batches=train_conf.get('stop_filtering')),
        FinishAfter(after_n_batches=train_conf.get('num_batches'),
                    after_n_epochs=train_conf.get('num_epochs')).add_condition(
                        ["after_batch"], _gradient_norm_is_none),
    ]
    channels = [
        # Plot 1: training and validation costs
        [
            average_monitoring.record_name(train_cost),
            validation.record_name(cost)
        ],
        # Plot 2: gradient norm,
        [
            average_monitoring.record_name(algorithm.total_gradient_norm),
            average_monitoring.record_name(clipping.threshold)
        ],
        # Plot 3: phoneme error rate
        [per_monitoring.record_name(per)],
        # Plot 4: training and validation mean weight entropy
        [
            average_monitoring._record_name('weights_entropy_per_label'),
            validation._record_name('weights_entropy_per_label')
        ],
        # Plot 5: training and validation monotonicity penalty
        [
            average_monitoring._record_name('weights_penalty_per_recording'),
            validation._record_name('weights_penalty_per_recording')
        ]
    ]
    if bokeh:
        extensions += [
            Plot(bokeh_name if bokeh_name else os.path.basename(save_path),
                 channels,
                 every_n_batches=10,
                 server_url=bokeh_server),
        ]
    extensions += [
        Checkpoint(save_path,
                   before_first_epoch=not fast_start,
                   after_epoch=True,
                   every_n_batches=train_conf.get('save_every_n_batches'),
                   save_separately=["model", "log"],
                   use_cpickle=True).add_condition(
                       ['after_epoch'],
                       OnLogRecord(track_the_best_per.notification_name),
                       (root_path + "_best" + extension, )).add_condition(
                           ['after_epoch'],
                           OnLogRecord(track_the_best_cost.notification_name),
                           (root_path + "_best_ll" + extension, )),
        ProgressBar()
    ]
    extensions.append(EmbedIPython(use_main_loop_run_caller_env=True))
    if config['net']['criterion']['name'].startswith('mse'):
        extensions.append(
            LogInputsGains(labels, cg, recognizer.generator.readout.emitter,
                           data))

    if train_conf.get('patience'):
        patience_conf = train_conf['patience']
        if not patience_conf.get('notification_names'):
            # setdefault will not work for empty list
            patience_conf['notification_names'] = [
                track_the_best_per.notification_name,
                track_the_best_cost.notification_name
            ]
        extensions.append(Patience(**patience_conf))

    extensions.append(
        Printing(every_n_batches=1, attribute_filter=PrintingFilterList()))

    return model, algorithm, data, extensions
コード例 #4
0
ファイル: main.py プロジェクト: ZhangAustin/attention-lvcsr
def train(config, save_path, bokeh_name,
          params, bokeh_server, test_tag, use_load_ext,
          load_log, fast_start, validation_epochs, validation_batches,
          per_epochs, per_batches):
    """Build a speech recognizer and train it with a Blocks main loop.

    The function constructs the data, the `SpeechRecognizer` brick, the
    (optionally regularized) cost graph, a `GradientDescent` algorithm and
    the monitoring/checkpointing extensions, then calls `MainLoop.run()`.
    Nothing is returned.

    Parameters
    ----------
    config : dict
        Experiment configuration. The sections 'data', 'net',
        'initialization', 'regularization' and 'training' are read here.
    save_path : str
        Checkpoint path. Best-model checkpoints are written next to it
        with `_best` / `_best_ll` inserted before the file extension.
    bokeh_name : str or None
        Name of the live plot; the basename of `save_path` is used when
        it is falsy.
    params : str or None
        Path to previously saved parameters. When truthy the recognizer
        loads them, and the same path is handed to the `Load` / `LoadLog`
        extensions if those are requested.
    bokeh_server : str
        Passed to `Plot` as `server_url`.
    test_tag : bool
        When true, the first training batch is attached as Theano test
        values and `theano.config.compute_test_value` is set to 'warn'.
    use_load_ext : bool
        Add a `Load` extension (only if `params` is given).
    load_log : bool
        Add a `LoadLog` extension (only if `params` is given).
    fast_start : bool
        When true, validation, PER search and checkpointing are not run
        before the first epoch.
    validation_epochs, validation_batches
        Frequencies of the likelihood validation.
    per_epochs, per_batches
        Frequencies of the phoneme error rate evaluation.
    """
    root_path, extension = os.path.splitext(save_path)

    data = Data(**config['data'])

    # Build the main brick and initialize all parameters.
    recognizer = SpeechRecognizer(
        data.recordings_source, data.labels_source,
        data.eos_label,
        data.num_features, data.num_labels,
        name="recognizer",
        data_prepend_eos=data.prepend_eos,
        character_map=data.character_map,
        **config["net"])
    # Brick paths containing more '/' separators are processed first.
    for brick_path, attribute_dict in sorted(
            config['initialization'].items(),
            key=lambda item: -item[0].count('/')):
        for attribute, value in attribute_dict.items():
            brick, = Selector(recognizer).select(brick_path).bricks
            setattr(brick, attribute, value)
            brick.push_initialization_config()
    recognizer.initialize()

    # Separate attention_params to be handled differently
    # when regularization is applied
    attention = recognizer.generator.transition.attention
    attention_params = Selector(attention).get_parameters().values()

    logger.info(
        "Initialization schemes for all bricks.\n"
        "Works well only in my branch with __repr__ added to all them,\n"
        "there is an issue #463 in Blocks to do that properly.")

    def show_init_scheme(cur):
        """Recursively collect the `*_init` attributes of a brick tree."""
        result = dict()
        for attr in dir(cur):
            if attr.endswith('_init'):
                result[attr] = getattr(cur, attr)
        for child in cur.children:
            result[child.name] = show_init_scheme(child)
        return result
    logger.info(pprint.pformat(show_init_scheme(recognizer)))

    if params:
        logger.info("Load parameters from " + params)
        recognizer.load_params(params)

    if test_tag:
        # Make variables print with their full repr and feed one real
        # training batch as test values.
        tensor.TensorVariable.__str__ = tensor.TensorVariable.__repr__
        test_stream = data.get_stream("train")
        test_batch = next(test_stream.get_epoch_iterator(as_dict=True))
        recognizer.recordings.tag.test_value = (
            test_batch[data.recordings_source])
        recognizer.recordings_mask.tag.test_value = (
            test_batch[data.recordings_source + '_mask'])
        recognizer.labels.tag.test_value = test_batch[data.labels_source]
        recognizer.labels_mask.tag.test_value = (
            test_batch[data.labels_source + '_mask'])
        theano.config.compute_test_value = 'warn'

    batch_cost = recognizer.get_cost_graph().sum()
    batch_size = named_copy(recognizer.recordings.shape[1], "batch_size")
    # Assumes constant batch size. `aggregation.mean` is not used because
    # of Blocks #514.
    cost = batch_cost / batch_size
    cost.name = "sequence_log_likelihood"
    logger.info("Cost graph is built")

    # Fetch variables useful for debugging.
    # It is important not to use any aggregation schemes here,
    # as it's currently impossible to spread the effect of
    # regularization on their variables, see Blocks #514.
    cost_cg = ComputationGraph(cost)
    energies, = VariableFilter(
        applications=[recognizer.generator.readout.readout],
        name="output_0")(cost_cg)
    bottom_output, = VariableFilter(
        applications=[recognizer.bottom.apply], name="output")(cost_cg)
    attended, = VariableFilter(
        applications=[recognizer.generator.transition.apply],
        name="attended")(cost_cg)
    attended_mask, = VariableFilter(
        applications=[recognizer.generator.transition.apply],
        name="attended_mask")(cost_cg)
    weights, = VariableFilter(
        applications=[recognizer.generator.evaluate],
        name="weights")(cost_cg)
    max_recording_length = named_copy(recognizer.recordings.shape[0],
                                      "max_recording_length")
    # To exclude subsampling related bugs
    max_attended_mask_length = named_copy(attended_mask.shape[0],
                                          "max_attended_mask_length")
    max_attended_length = named_copy(attended.shape[0],
                                     "max_attended_length")
    max_num_phonemes = named_copy(recognizer.labels.shape[0],
                                  "max_num_phonemes")
    min_energy = named_copy(energies.min(), "min_energy")
    max_energy = named_copy(energies.max(), "max_energy")
    mean_attended = named_copy(abs(attended).mean(),
                               "mean_attended")
    mean_bottom_output = named_copy(abs(bottom_output).mean(),
                                    "mean_bottom_output")
    weights_penalty = named_copy(
        monotonicity_penalty(weights, recognizer.labels_mask),
        "weights_penalty")
    weights_entropy = named_copy(
        entropy(weights, recognizer.labels_mask),
        "weights_entropy")
    mask_density = named_copy(recognizer.labels_mask.mean(),
                              "mask_density")
    cg = ComputationGraph([
        cost, weights_penalty, weights_entropy,
        min_energy, max_energy,
        mean_attended, mean_bottom_output,
        batch_size, max_num_phonemes,
        mask_density])

    # Regularization. It is applied explicitly to all variables
    # of interest, it could not be applied to the cost only as it
    # would not have effect on auxiliary variables, see Blocks #514.
    reg_config = config['regularization']
    regularized_cg = cg
    if reg_config.get('dropout'):
        logger.info('apply dropout')
        regularized_cg = apply_dropout(cg, [bottom_output], 0.5)
    if reg_config.get('noise'):
        logger.info('apply noise')
        noise_subjects = [p for p in cg.parameters if p not in attention_params]
        # Noise is stacked on top of the (possibly dropout-regularized)
        # graph; starting again from `cg` would silently drop the dropout
        # when both options are enabled.
        regularized_cg = apply_noise(
            regularized_cg, noise_subjects, reg_config['noise'])
    regularized_cost = regularized_cg.outputs[0]
    regularized_weights_penalty = regularized_cg.outputs[1]

    # The model extracts all the parameters from the computation graph
    # and gives them hierarchical names. Logging them helps to notice
    # when, because of some bug, a parameter is not in the graph.
    model = SpeechModel(regularized_cost)
    # Kept under its own name: `params` is the checkpoint path argument
    # and is still needed by the Load/LoadLog extensions below.
    parameters = model.get_parameter_dict()
    logger.info("Parameters:\n" +
                pprint.pformat(
                    [(key, parameters[key].get_value().shape)
                     for key in sorted(parameters.keys())],
                    width=120))

    # Define the training algorithm.
    train_conf = config['training']
    clipping = StepClipping(train_conf['gradient_threshold'])
    clipping.threshold.name = "gradient_norm_threshold"
    rule_names = train_conf.get('rules', ['momentum'])
    core_rules = []
    if 'momentum' in rule_names:
        logger.info("Using scaling and momentum for training")
        core_rules.append(
            Momentum(train_conf['scale'], train_conf['momentum']))
    if 'adadelta' in rule_names:
        logger.info("Using AdaDelta for training")
        core_rules.append(
            AdaDelta(train_conf['decay_rate'], train_conf['epsilon']))
    max_norm_rules = []
    if reg_config.get('max_norm', False):
        logger.info("Apply MaxNorm")
        maxnorm_subjects = VariableFilter(roles=[WEIGHT])(cg.parameters)
        if reg_config.get('max_norm_exclude_lookup', False):
            maxnorm_subjects = [v for v in maxnorm_subjects
                                if not isinstance(get_brick(v), LookupTable)]
        logger.info("Parameters covered by MaxNorm:\n"
                    + pprint.pformat([name for name, p in parameters.items()
                                      if p in maxnorm_subjects]))
        logger.info("Parameters NOT covered by MaxNorm:\n"
                    + pprint.pformat([name for name, p in parameters.items()
                                      if p not in maxnorm_subjects]))
        max_norm_rules = [
            Restrict(VariableClipping(reg_config['max_norm'], axis=0),
                     maxnorm_subjects)]
    # Total training cost: regularized likelihood, plus the optional
    # monotonicity penalty and the optional L2 weight decay.
    train_cost = (
        regularized_cost +
        reg_config.get("penalty_coof", .0) *
        regularized_weights_penalty / batch_size +
        reg_config.get("decay", .0) *
        l2_norm(VariableFilter(roles=[WEIGHT])(cg.parameters)) ** 2)
    algorithm = GradientDescent(
        cost=train_cost,
        parameters=list(parameters.values()),
        step_rule=CompositeRule(
            [clipping] + core_rules + max_norm_rules +
            # Parameters are not changed at all
            # when nans are encountered.
            [RemoveNotFinite(0.0)]))

    # More variables for debugging: some of them can be added only
    # after the `algorithm` object is created.
    # Work on a copy so the graph's own `outputs` list is not extended.
    observables = list(regularized_cg.outputs)
    observables += [
        algorithm.total_step_norm, algorithm.total_gradient_norm,
        clipping.threshold]
    for name, param in parameters.items():
        num_elements = numpy.prod(param.get_value().shape)
        norm = param.norm(2) / num_elements ** 0.5
        grad_norm = algorithm.gradients[param].norm(2) / num_elements ** 0.5
        step_norm = algorithm.steps[param].norm(2) / num_elements ** 0.5
        stats = tensor.stack(norm, grad_norm, step_norm, step_norm / grad_norm)
        stats.name = name + '_stats'
        observables.append(stats)

    def attach_aggregation_schemes(variables):
        # Aggregation specification has to be factored out as a separate
        # function as it has to be applied at the very last stage
        # separately to training and validation observables.
        result = []
        for var in variables:
            if var.name == 'weights_penalty':
                result.append(named_copy(aggregation.mean(var, batch_size),
                                         'weights_penalty_per_recording'))
            elif var.name == 'weights_entropy':
                result.append(named_copy(
                    aggregation.mean(var, recognizer.labels_mask.sum()),
                    'weights_entropy_per_label'))
            else:
                result.append(var)
        return result

    # Build main loop.
    logger.info("Initialize extensions")
    extensions = []
    # `params` here is the checkpoint path given by the caller.
    if use_load_ext and params:
        extensions.append(
            Load(params, load_iteration_state=True, load_log=True))
    if load_log and params:
        extensions.append(LoadLog(params))
    extensions += [
        Timing(after_batch=True),
        CGStatistics(),
        #CodeVersion(['lvsr']),
        ]
    # Per-batch monitoring of a few cheap quantities.
    extensions.append(TrainingDataMonitoring(
        [observables[0], algorithm.total_gradient_norm,
         algorithm.total_step_norm, clipping.threshold,
         max_recording_length,
         max_attended_length, max_attended_mask_length], after_batch=True))
    average_monitoring = TrainingDataMonitoring(
        attach_aggregation_schemes(observables),
        prefix="average", every_n_batches=10)
    extensions.append(average_monitoring)
    validation = DataStreamMonitoring(
        attach_aggregation_schemes([cost, weights_entropy, weights_penalty]),
        data.get_stream("valid"), prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=validation_epochs,
            every_n_batches=validation_batches,
            after_training=False)
    extensions.append(validation)
    recognizer.init_beam_search(10)
    per = PhonemeErrorRate(recognizer, data.get_dataset("valid"))
    per_monitoring = DataStreamMonitoring(
        [per], data.get_stream("valid", batches=False, shuffle=False),
        prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=per_epochs,
            every_n_batches=per_batches,
            after_training=False)
    extensions.append(per_monitoring)
    track_the_best_per = TrackTheBest(
        per_monitoring.record_name(per)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    track_the_best_likelihood = TrackTheBest(
        validation.record_name(cost)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    extensions += [track_the_best_likelihood, track_the_best_per]
    extensions.append(AdaptiveClipping(
        algorithm.total_gradient_norm.name,
        clipping, train_conf['gradient_threshold'],
        decay_rate=0.998, burnin_period=500))
    extensions += [
        SwitchOffLengthFilter(
            data.length_filter,
            after_n_batches=train_conf.get('stop_filtering')),
        FinishAfter(after_n_batches=train_conf['num_batches'],
                    after_n_epochs=train_conf['num_epochs'])
        .add_condition(["after_batch"], _gradient_norm_is_none),
        # Live plotting: requires launching `bokeh-server`
        # and allows to see what happens online.
        Plot(bokeh_name
             if bokeh_name
             else os.path.basename(save_path),
             [# Plot 1: training and validation costs
              [average_monitoring.record_name(regularized_cost),
               validation.record_name(cost)],
              # Plot 2: gradient norm,
              [average_monitoring.record_name(algorithm.total_gradient_norm),
               average_monitoring.record_name(clipping.threshold)],
              # Plot 3: phoneme error rate
              [per_monitoring.record_name(per)],
              # Plot 4: training and validation mean weight entropy
              [average_monitoring._record_name('weights_entropy_per_label'),
               validation._record_name('weights_entropy_per_label')],
              # Plot 5: training and validation monotonicity penalty
              [average_monitoring._record_name('weights_penalty_per_recording'),
               validation._record_name('weights_penalty_per_recording')]],
             every_n_batches=10,
             server_url=bokeh_server),
        # Regular checkpoints, plus separate files whenever the best PER
        # or the best validation likelihood has just been recorded.
        Checkpoint(save_path,
                   before_first_epoch=not fast_start, after_epoch=True,
                   every_n_batches=train_conf.get('save_every_n_batches'),
                   save_separately=["model", "log"],
                   use_cpickle=True)
        .add_condition(
            ['after_epoch'],
            OnLogRecord(track_the_best_per.notification_name),
            (root_path + "_best" + extension,))
        .add_condition(
            ['after_epoch'],
            OnLogRecord(track_the_best_likelihood.notification_name),
            (root_path + "_best_ll" + extension,)),
        ProgressBar(),
        Printing(every_n_batches=1,
                 attribute_filter=PrintingFilterList())]

    # Save the config into the status
    log = TrainingLog()
    log.status['_config'] = repr(config)
    main_loop = MainLoop(
        model=model, log=log, algorithm=algorithm,
        data_stream=data.get_stream("train"),
        extensions=extensions)
    main_loop.run()
コード例 #5
0
# Training algorithm: gradient descent on `cost` over every parameter of
# the computation graph, with the configured step rules chained together.
algorithm = GradientDescent(
    cost=cost,
    parameters=cg.parameters,
    step_rule=CompositeRule(step_rules))

# Extensions
# Aggregated norms of the gradient and of the update step, monitored
# alongside the cost on the training data.
gradient_norm = aggregation.mean(algorithm.total_gradient_norm)
step_norm = aggregation.mean(algorithm.total_step_norm)
monitored_vars = [cost, gradient_norm, step_norm]

# Cost on the development stream, recorded under the "dev" prefix.
dev_monitor = DataStreamMonitoring(
    variables=[cost],
    data_stream=dev_stream,
    prefix="dev",
    after_batch=True,
    before_first_epoch=True)
# Cost and norms on the training batches, recorded under "train".
train_monitor = TrainingDataMonitoring(
    variables=monitored_vars,
    prefix='train',
    after_batch=True,
    before_first_epoch=True)

# Live plot with one channel group for the training cost and one for
# the development cost.
plotter = Plot(
    'RNN char-level prediction',
    channels=[
        [train_monitor.record_name(cost)],
        [dev_monitor.record_name(cost)],
    ],
    server_url="http://bart4.iro.umontreal.ca:5006",
    after_batch=True)


if start_from_checkpoint == True:
    extensions = [dev_monitor, train_monitor, Timing(), Printing(after_batch=True),
                FinishAfter(after_n_epochs=num_epochs),
                saveload.Load(last_path, load_log=True), plotter,
                saveload.Checkpoint(last_path, save_separately=['log']),
                ] + track_best('dev_cost', save_path)
else: #start fresh
    extensions = [dev_monitor, train_monitor, Timing(), Printing(after_batch=True),
                FinishAfter(after_n_epochs=num_epochs),
                saveload.Load(load_path), plotter,
                saveload.Checkpoint(last_path, save_separately=['log']),
コード例 #6
0
def train_snli_model(new_training_job,
                     config,
                     save_path,
                     params,
                     fast_start,
                     fuel_server,
                     seed,
                     model='simple'):
    if config['exclude_top_k'] > config['num_input_words'] and config[
            'num_input_words'] > 0:
        raise Exception("Some words have neither word nor def embedding")
    c = config
    logger = configure_logger(name="snli_baseline_training",
                              log_file=os.path.join(save_path, "log.txt"))
    if not os.path.exists(save_path):
        logger.info("Start a new job")
        os.mkdir(save_path)
    else:
        logger.info("Continue an existing job")
    with open(os.path.join(save_path, "cmd.txt"), "w") as f:
        f.write(" ".join(sys.argv))

    # Make data paths nice
    for path in [
            'dict_path', 'embedding_def_path', 'embedding_path', 'vocab',
            'vocab_def', 'vocab_text'
    ]:
        if c.get(path, ''):
            if not os.path.isabs(c[path]):
                c[path] = os.path.join(fuel.config.data_path[0], c[path])

    main_loop_path = os.path.join(save_path, 'main_loop.tar')
    main_loop_best_val_path = os.path.join(save_path, 'main_loop_best_val.tar')
    stream_path = os.path.join(save_path, 'stream.pkl')

    # Save config to save_path
    json.dump(config, open(os.path.join(save_path, "config.json"), "w"))

    if model == 'simple':
        nli_model, data, used_dict, used_retrieval, _ = _initialize_simple_model_and_data(
            c)
    elif model == 'esim':
        nli_model, data, used_dict, used_retrieval, _ = _initialize_esim_model_and_data(
            c)
    else:
        raise NotImplementedError()

    # Compute cost
    s1, s2 = T.lmatrix('sentence1'), T.lmatrix('sentence2')

    if c['dict_path']:
        assert os.path.exists(c['dict_path'])
        s1_def_map, s2_def_map = T.lmatrix('sentence1_def_map'), T.lmatrix(
            'sentence2_def_map')
        def_mask = T.fmatrix("def_mask")
        defs = T.lmatrix("defs")
    else:
        s1_def_map, s2_def_map = None, None
        def_mask = None
        defs = None

    s1_mask, s2_mask = T.fmatrix('sentence1_mask'), T.fmatrix('sentence2_mask')
    y = T.ivector('label')

    cg = {}
    for train_phase in [True, False]:
        # NOTE: Please don't change outputs of cg
        if train_phase:
            with batch_normalization(nli_model):
                pred = nli_model.apply(s1,
                                       s1_mask,
                                       s2,
                                       s2_mask,
                                       def_mask=def_mask,
                                       defs=defs,
                                       s1_def_map=s1_def_map,
                                       s2_def_map=s2_def_map,
                                       train_phase=train_phase)
        else:
            pred = nli_model.apply(s1,
                                   s1_mask,
                                   s2,
                                   s2_mask,
                                   def_mask=def_mask,
                                   defs=defs,
                                   s1_def_map=s1_def_map,
                                   s2_def_map=s2_def_map,
                                   train_phase=train_phase)

        cost = CategoricalCrossEntropy().apply(y.flatten(), pred)
        error_rate = MisclassificationRate().apply(y.flatten(), pred)
        cg[train_phase] = ComputationGraph([cost, error_rate])

    # Weight decay (TODO: Make it less bug prone)
    if model == 'simple':
        weights_to_decay = VariableFilter(
            bricks=[dense for dense, relu, bn in nli_model._mlp],
            roles=[WEIGHT])(cg[True].variables)
        weight_decay = np.float32(c['l2']) * sum(
            (w**2).sum() for w in weights_to_decay)
    elif model == 'esim':
        weight_decay = 0.0
    else:
        raise NotImplementedError()

    final_cost = cg[True].outputs[0] + weight_decay
    final_cost.name = 'final_cost'

    # Add updates for population parameters

    if c.get("bn", True):
        pop_updates = get_batch_normalization_updates(cg[True])
        extra_updates = [(p, m * 0.1 + p * (1 - 0.1)) for p, m in pop_updates]
    else:
        pop_updates = []
        extra_updates = []

    if params:
        logger.debug("Load parameters from {}".format(params))
        with open(params) as src:
            loaded_params = load_parameters(src)
            cg[True].set_parameter_values(loaded_params)
            for param, m in pop_updates:
                param.set_value(loaded_params[get_brick(
                    param).get_hierarchical_name(param)])

    if os.path.exists(os.path.join(save_path, "main_loop.tar")):
        logger.warning("Manually loading BN stats :(")
        with open(os.path.join(save_path, "main_loop.tar")) as src:
            loaded_params = load_parameters(src)

        for param, m in pop_updates:
            param.set_value(
                loaded_params[get_brick(param).get_hierarchical_name(param)])

    if theano.config.compute_test_value != 'off':
        test_value_data = next(
            data.get_stream('train', batch_size=4).get_epoch_iterator())
        s1.tag.test_value = test_value_data[0]
        s1_mask.tag.test_value = test_value_data[1]
        s2.tag.test_value = test_value_data[2]
        s2_mask.tag.test_value = test_value_data[3]
        y.tag.test_value = test_value_data[4]

    # Freeze embeddings
    if not c['train_emb']:
        frozen_params = [
            p for E in nli_model.get_embeddings_lookups() for p in E.parameters
        ]
        train_params = [p for p in cg[True].parameters]
        assert len(set(frozen_params) & set(train_params)) > 0
    else:
        frozen_params = []
    if not c.get('train_def_emb', 1):
        frozen_params_def = [
            p for E in nli_model.get_def_embeddings_lookups()
            for p in E.parameters
        ]
        train_params = [p for p in cg[True].parameters]
        assert len(set(frozen_params_def) & set(train_params)) > 0
        frozen_params += frozen_params_def
    train_params = [p for p in cg[True].parameters if p not in frozen_params]
    train_params_keys = [
        get_brick(p).get_hierarchical_name(p) for p in train_params
    ]

    # Optimizer
    algorithm = GradientDescent(cost=final_cost,
                                on_unused_sources='ignore',
                                parameters=train_params,
                                step_rule=Adam(learning_rate=c['lr']))
    algorithm.add_updates(extra_updates)
    m = Model(final_cost)

    parameters = m.get_parameter_dict()  # Blocks version mismatch
    logger.info("Trainable parameters" + "\n" +
                pprint.pformat([(key, parameters[key].get_value().shape)
                                for key in sorted(train_params_keys)],
                               width=120))
    logger.info("# of parameters {}".format(
        sum([
            np.prod(parameters[key].get_value().shape)
            for key in sorted(train_params_keys)
        ])))

    ### Monitored args ###
    train_monitored_vars = [final_cost] + cg[True].outputs
    monitored_vars = cg[False].outputs
    val_acc = monitored_vars[1]
    to_monitor_names = [
        'def_unk_ratio', 's1_merged_input_rootmean2', 's1_def_mean_rootmean2',
        's1_gate_rootmean2', 's1_compose_gate_rootmean2'
    ]
    for k in to_monitor_names:
        train_v, valid_v = VariableFilter(name=k)(
            cg[True]), VariableFilter(name=k)(cg[False])
        if len(train_v):
            logger.info("Adding {} tracking".format(k))
            train_monitored_vars.append(train_v[0])
            monitored_vars.append(valid_v[0])
        else:
            logger.warning("Didnt find {} in cg".format(k))

    if c['monitor_parameters']:
        for name in train_params_keys:
            param = parameters[name]
            num_elements = numpy.product(param.get_value().shape)
            norm = param.norm(2) / num_elements
            grad_norm = algorithm.gradients[param].norm(2) / num_elements
            step_norm = algorithm.steps[param].norm(2) / num_elements
            stats = tensor.stack(norm, grad_norm, step_norm,
                                 step_norm / grad_norm)
            stats.name = name + '_stats'
            train_monitored_vars.append(stats)

    regular_training_stream = data.get_stream('train',
                                              batch_size=c['batch_size'],
                                              seed=seed)

    if fuel_server:
        # the port will be configured by the StartFuelServer extension
        training_stream = ServerDataStream(
            sources=regular_training_stream.sources,
            hwm=100,
            produces_examples=regular_training_stream.produces_examples)
    else:
        training_stream = regular_training_stream

    ### Build extensions ###

    extensions = [
        # Load(main_loop_path, load_iteration_state=True, load_log=True)
        #     .set_conditions(before_training=not new_training_job),
        StartFuelServer(regular_training_stream,
                        stream_path,
                        hwm=100,
                        script_path=os.path.join(
                            os.path.dirname(__file__),
                            "../bin/start_fuel_server.py"),
                        before_training=fuel_server),
        Timing(every_n_batches=c['mon_freq']),
        ProgressBar(),
        RetrievalPrintStats(retrieval=used_retrieval,
                            every_n_batches=c['mon_freq_valid'],
                            before_training=not fast_start),
        Timestamp(),
        TrainingDataMonitoring(train_monitored_vars,
                               prefix="train",
                               every_n_batches=c['mon_freq']),
    ]

    if c['layout'] == 'snli':
        validation = DataStreamMonitoring(monitored_vars,
                                          data.get_stream('valid',
                                                          batch_size=14,
                                                          seed=seed),
                                          before_training=not fast_start,
                                          on_resumption=True,
                                          after_training=True,
                                          every_n_batches=c['mon_freq_valid'],
                                          prefix='valid')
        extensions.append(validation)
    elif c['layout'] == 'mnli':
        validation = DataStreamMonitoring(monitored_vars,
                                          data.get_stream('valid_matched',
                                                          batch_size=14,
                                                          seed=seed),
                                          every_n_batches=c['mon_freq_valid'],
                                          on_resumption=True,
                                          after_training=True,
                                          prefix='valid_matched')
        validation_mismatched = DataStreamMonitoring(
            monitored_vars,
            data.get_stream('valid_mismatched', batch_size=14, seed=seed),
            every_n_batches=c['mon_freq_valid'],
            before_training=not fast_start,
            on_resumption=True,
            after_training=True,
            prefix='valid_mismatched')
        extensions.extend([validation, validation_mismatched])
    else:
        raise NotImplementedError()

    # Similarity trackers for embeddings
    if len(c.get('vocab_def', '')):
        retrieval_vocab = Vocabulary(c['vocab_def'])
    else:
        retrieval_vocab = data.vocab

    retrieval_all = Retrieval(vocab_text=retrieval_vocab,
                              dictionary=used_dict,
                              max_def_length=c['max_def_length'],
                              exclude_top_k=0,
                              max_def_per_word=c['max_def_per_word'])

    for name in [
            's1_word_embeddings', 's1_dict_word_embeddings',
            's1_translated_word_embeddings'
    ]:
        variables = VariableFilter(name=name)(cg[False])
        if len(variables):
            s1_emb = variables[0]
            logger.info("Adding similarity tracking for " + name)
            # A bit sloppy about downcast

            if "dict" in name:
                embedder = construct_dict_embedder(theano.function(
                    [s1, defs, def_mask, s1_def_map],
                    s1_emb,
                    allow_input_downcast=True),
                                                   vocab=data.vocab,
                                                   retrieval=retrieval_all)
                extensions.append(
                    SimilarityWordEmbeddingEval(
                        embedder=embedder,
                        prefix=name,
                        every_n_batches=c['mon_freq_valid'],
                        before_training=not fast_start))
            else:
                embedder = construct_embedder(theano.function(
                    [s1], s1_emb, allow_input_downcast=True),
                                              vocab=data.vocab)
                extensions.append(
                    SimilarityWordEmbeddingEval(
                        embedder=embedder,
                        prefix=name,
                        every_n_batches=c['mon_freq_valid'],
                        before_training=not fast_start))

    track_the_best = TrackTheBest(validation.record_name(val_acc),
                                  before_training=not fast_start,
                                  every_n_epochs=c['save_freq_epochs'],
                                  after_training=not fast_start,
                                  every_n_batches=c['mon_freq_valid'],
                                  choose_best=min)
    extensions.append(track_the_best)

    # Special care for serializing embeddings
    if len(c.get('embedding_path', '')) or len(c.get('embedding_def_path',
                                                     '')):
        extensions.insert(
            0,
            LoadNoUnpickling(main_loop_path,
                             load_iteration_state=True,
                             load_log=True).set_conditions(
                                 before_training=not new_training_job))
        extensions.append(
            Checkpoint(main_loop_path,
                       parameters=train_params + [p for p, m in pop_updates],
                       save_main_loop=False,
                       save_separately=['log', 'iteration_state'],
                       before_training=not fast_start,
                       every_n_epochs=c['save_freq_epochs'],
                       after_training=not fast_start).add_condition(
                           ['after_batch', 'after_epoch'],
                           OnLogRecord(track_the_best.notification_name),
                           (main_loop_best_val_path, )))
    else:
        extensions.insert(
            0,
            Load(main_loop_path, load_iteration_state=True,
                 load_log=True).set_conditions(
                     before_training=not new_training_job))
        extensions.append(
            Checkpoint(main_loop_path,
                       parameters=cg[True].parameters +
                       [p for p, m in pop_updates],
                       before_training=not fast_start,
                       every_n_epochs=c['save_freq_epochs'],
                       after_training=not fast_start).add_condition(
                           ['after_batch', 'after_epoch'],
                           OnLogRecord(track_the_best.notification_name),
                           (main_loop_best_val_path, )))

    extensions.extend([
        DumpCSVSummaries(save_path,
                         every_n_batches=c['mon_freq_valid'],
                         after_training=True),
        DumpTensorflowSummaries(save_path,
                                after_epoch=True,
                                every_n_batches=c['mon_freq_valid'],
                                after_training=True),
        Printing(every_n_batches=c['mon_freq_valid']),
        PrintMessage(msg="save_path={}".format(save_path),
                     every_n_batches=c['mon_freq']),
        FinishAfter(after_n_batches=c['n_batches']).add_condition(
            ['after_batch'],
            OnLogStatusExceed('iterations_done', c['n_batches']))
    ])

    logger.info(extensions)

    ### Run training ###

    if "VISDOM_SERVER" in os.environ:
        print("Running visdom server")
        ret = subprocess.Popen([
            os.path.join(os.path.dirname(__file__), "../visdom_plotter.py"),
            "--visdom-server={}".format(os.environ['VISDOM_SERVER']),
            "--folder={}".format(save_path)
        ])
        time.sleep(0.1)
        if ret.returncode is not None:
            raise Exception()
        atexit.register(lambda: os.kill(ret.pid, signal.SIGINT))

    model = Model(cost)
    for p, m in pop_updates:
        model._parameter_dict[get_brick(p).get_hierarchical_name(p)] = p

    main_loop = MainLoop(algorithm,
                         training_stream,
                         model=model,
                         extensions=extensions)

    assert os.path.exists(save_path)
    main_loop.run()
コード例 #7
0
def main(mode, save_path, steps, num_batches, load_params):
    """Train, resume, sample from or evaluate a character-level language model.

    The model is a ``SequenceGenerator`` built around a ``GatedRecurrent``
    transition, fed by ``TextFile`` datasets read at character level.

    Parameters
    ----------
    mode : str
        One of ``'continue'``, ``'sample'``, ``'train'`` or ``'evaluate'``.
    save_path : str
        In ``'continue'`` mode it is passed to ``continue_training``; in
        ``'sample'`` mode it is the file the main loop is loaded from; in
        ``'train'`` mode it is handed to the ``Checkpoint`` extensions (and
        its stem is used for the "_best" copy and for ``Plot``); with
        ``load_params`` it is also where parameter values are read from.
    steps : int
        Number of generation steps requested in ``'sample'`` mode.
    num_batches
        Not used by this function.
    load_params : bool
        When truthy, parameter values are loaded from ``save_path`` into the
        freshly built model.

    Raises
    ------
    AssertionError
        If ``mode`` is none of the supported values. Note that this is only
        detected after the datasets and the model have been built.
    """
    # NOTE(review): ``range(10)`` contributes the integers 0..9 to the
    # vocabulary, not the digit characters '0'..'9'. This looks unintended
    # (``"".join`` in sample mode would fail on an int), but changing it
    # alters the index mapping of already trained models -- confirm before
    # fixing.
    chars = (list(string.ascii_uppercase) + list(range(10)) +
             [' ', '.', ',', '\'', '"', '!', '?', '<UNK>'])
    char_to_ind = {char: i for i, char in enumerate(chars)}
    # ``items`` (instead of ``iteritems``) works on Python 2 and Python 3.
    ind_to_char = {v: k for k, v in char_to_ind.items()}

    train_dataset = TextFile(['/Tmp/serdyuk/data/wsj_text_train'],
                             char_to_ind, bos_token=None, eos_token=None,
                             level='character')
    valid_dataset = TextFile(['/Tmp/serdyuk/data/wsj_text_valid'],
                             char_to_ind, bos_token=None, eos_token=None,
                             level='character')

    vocab_size = len(char_to_ind)
    logger.info('Dictionary size: {}'.format(vocab_size))
    if mode == 'continue':
        continue_training(save_path)
        return
    elif mode == "sample":
        # Close the checkpoint file as soon as the main loop is loaded.
        with open(save_path, "rb") as src:
            main_loop = load(src)
        generator = main_loop.model.get_top_bricks()[-1]

        sample = ComputationGraph(generator.generate(
            n_steps=steps, batch_size=1, iterate=True)).get_theano_function()

        # Keep only the first (and only, since batch_size=1) batch column.
        states, outputs, costs = [data[:, 0] for data in sample()]
        print("".join([ind_to_char[s] for s in outputs]))

        numpy.set_printoptions(precision=3, suppress=True)
        print("Generation cost:\n{}".format(costs.sum()))

        # Unigram and bigram statistics of the sample. They are computed but
        # not used any further in this function.
        freqs = numpy.bincount(outputs).astype(floatX)
        freqs /= freqs.sum()

        trans_freqs = numpy.zeros((vocab_size, vocab_size), dtype=floatX)
        for a, b in zip(outputs, outputs[1:]):
            trans_freqs[a, b] += 1
        trans_freqs /= trans_freqs.sum(axis=1)[:, None]
        return

    # Experiment configuration
    batch_size = 20
    dim = 650
    feedback_dim = 650

    # Validation pipeline: examples -> batches -> padding -> ``_transpose``.
    valid_stream = valid_dataset.get_example_stream()
    valid_stream = Batch(valid_stream,
                         iteration_scheme=ConstantScheme(batch_size))
    valid_stream = Padding(valid_stream)
    valid_stream = Mapping(valid_stream, _transpose)

    # Build the bricks and initialize them

    transition = GatedRecurrent(name="transition", dim=dim,
                                activation=Tanh())
    generator = SequenceGenerator(
        Readout(readout_dim=vocab_size, source_names=transition.apply.states,
                emitter=SoftmaxEmitter(name="emitter"),
                feedback_brick=LookupFeedback(
                    vocab_size, feedback_dim, name='feedback'),
                name="readout"),
        transition,
        weights_init=Uniform(std=0.04), biases_init=Constant(0),
        name="generator")
    generator.push_initialization_config()
    # The recurrent transition overrides the generator-wide weight init.
    transition.weights_init = Orthogonal()
    transition.push_initialization_config()
    generator.initialize()

    # Build the cost computation graph.
    features = tensor.lmatrix('features')
    features_mask = tensor.matrix('features_mask')
    cost_matrix = generator.cost_matrix(
        features, mask=features_mask)
    batch_cost = cost_matrix.sum()
    # Cost normalised by ``features.shape[1]`` ...
    cost = aggregation.mean(
        batch_cost,
        features.shape[1])
    cost.name = "sequence_log_likelihood"
    # ... and by the number of unmasked positions.
    char_cost = aggregation.mean(
        batch_cost, features_mask.sum())
    char_cost.name = 'character_log_likelihood'
    ppl = 2 ** (cost / numpy.log(2))
    ppl.name = 'ppl'
    bits_per_char = char_cost / tensor.log(2)
    bits_per_char.name = 'bits_per_char'
    length = features.shape[0]
    length.name = 'length'

    model = Model(batch_cost)
    if load_params:
        params = load_parameter_values(save_path)
        model.set_parameter_values(params)

    if mode == "train":
        # Give an idea of what's going on.
        logger.info("Parameters:\n" +
                    pprint.pformat(
                        [(key, value.get_value().shape) for key, value
                         in Selector(generator).get_parameters().items()],
                        width=120))

        # Training pipeline: like validation, with ``_truncate`` applied to
        # every example before batching.
        train_stream = train_dataset.get_example_stream()
        train_stream = Mapping(train_stream, _truncate)
        train_stream = Batch(train_stream,
                             iteration_scheme=ConstantScheme(batch_size))
        train_stream = Padding(train_stream)
        train_stream = Mapping(train_stream, _transpose)

        parameters = model.get_parameter_dict()
        # Only needed by the max-norm restriction that is disabled below.
        maxnorm_subjects = VariableFilter(roles=[WEIGHT])(parameters.values())
        algorithm = GradientDescent(
            cost=batch_cost,
            parameters=parameters.values(),
            step_rule=CompositeRule([
                StepClipping(1000.),
                # Disabled:
                # Restrict(VariableClipping(1.0, axis=0), maxnorm_subjects)
                AdaDelta(epsilon=1e-8),
            ]))
        ft = features[:6, 0]
        ft.name = 'feature_example'

        observables = [cost, ppl, char_cost, length, bits_per_char]
        # Per-parameter statistics: norms of the parameter, its gradient and
        # its step (each divided by sqrt(#elements)), plus the step/gradient
        # ratio.
        for name, param in parameters.items():
            num_elements = numpy.product(param.get_value().shape)
            norm = param.norm(2) / num_elements ** 0.5
            grad_norm = (algorithm.gradients[param].norm(2) /
                         num_elements ** 0.5)
            step_norm = algorithm.steps[param].norm(2) / num_elements ** 0.5
            stats = tensor.stack(norm, grad_norm, step_norm,
                                 step_norm / grad_norm)
            stats.name = name + '_stats'
            observables.append(stats)
        track_the_best_bpc = TrackTheBest('valid_bits_per_char')
        root_path, extension = os.path.splitext(save_path)

        this_step_monitoring = TrainingDataMonitoring(
            observables + [ft], prefix="this_step", after_batch=True)
        average_monitoring = TrainingDataMonitoring(
            observables + [algorithm.total_step_norm,
                           algorithm.total_gradient_norm],
            prefix="average",
            every_n_batches=10)
        valid_monitoring = DataStreamMonitoring(
            observables, prefix="valid",
            every_n_batches=1500, before_training=False,
            data_stream=valid_stream)
        main_loop = MainLoop(
            algorithm=algorithm,
            data_stream=train_stream,
            model=model,
            extensions=[
                this_step_monitoring,
                average_monitoring,
                valid_monitoring,
                track_the_best_bpc,
                Checkpoint(save_path, ),
                # Periodic checkpoint; the extra condition passes the
                # "<root>_best<ext>" path after an epoch in which the log
                # holds the record named by
                # ``track_the_best_bpc.notification_name``.
                Checkpoint(save_path,
                           every_n_batches=500,
                           save_separately=["model", "log"],
                           use_cpickle=True)
                    .add_condition(
                    ['after_epoch'],
                    OnLogRecord(track_the_best_bpc.notification_name),
                    (root_path + "_best" + extension,)),
                Timing(after_batch=True),
                Printing(every_n_batches=10),
                Plot(root_path,
                     [[average_monitoring.record_name(cost),
                       valid_monitoring.record_name(cost)],
                      [average_monitoring.record_name(
                          algorithm.total_step_norm)],
                      [average_monitoring.record_name(
                          algorithm.total_gradient_norm)],
                      [average_monitoring.record_name(ppl),
                       valid_monitoring.record_name(ppl)],
                      [average_monitoring.record_name(char_cost),
                       valid_monitoring.record_name(char_cost)],
                      [average_monitoring.record_name(bits_per_char),
                       valid_monitoring.record_name(bits_per_char)]],
                     every_n_batches=10)
            ])
        main_loop.run()

    elif mode == 'evaluate':
        unk_ind = char_to_ind['<UNK>']
        with open('/data/lisatmp3/serdyuk/wsj_lms/lms/wsj_trigram_with_initial_eos/lexicon.txt') as f:
            # Drop the first and the last whitespace-separated token of
            # every lexicon line and map the remaining tokens to indices.
            raw_words = [line.split()[1:-1] for line in f]
        words = [[char_to_ind.get(c, unk_ind) for c in w]
                 for w in raw_words]
        max_word_length = max(len(w) for w in words)

        # Cost function that starts from explicitly given recurrent states
        # and also returns the states it went through.
        initial_states = tensor.matrix('init_states')
        cost_matrix_step = generator.cost_matrix(features, mask=features_mask,
                                                 states=initial_states)
        cg = ComputationGraph(cost_matrix_step)
        # Assumes the states are the second-to-last auxiliary variable of
        # the graph -- TODO confirm.
        states = cg.auxiliary_variables[-2]
        compute_cost = theano.function(
            [features, features_mask, initial_states],
            [cost_matrix_step.sum(axis=0), states])

        # NOTE(review): the three names bound below are not read again in
        # this function.
        cost_matrix = generator.cost_matrix(features, mask=features_mask)
        initial_cg = ComputationGraph(cost_matrix)
        initial_states = initial_cg.auxiliary_variables[-2]

        total_word_cost = 0
        num_words = 0
        # One column per lexicon entry, zero padded, with one extra row.
        examples = numpy.zeros((max_word_length + 1, len(words)),
                               dtype='int64')
        all_masks = numpy.zeros((max_word_length + 1, len(words)),
                                dtype=floatX)

        for i, word in enumerate(words):
            examples[:len(word), i] = word
            all_masks[:len(word), i] = 1.

        single_space = numpy.array([char_to_ind[' ']])[:, None]

        for batch in valid_stream.get_epoch_iterator():
            for example, mask in equizip(batch[0].T, batch[1].T):
                # Strip the padding, then split the sequence at spaces.
                example = example[:(mask.sum())]
                spc_inds = list(numpy.where(example == char_to_ind[" "])[0])
                state = generator.transition.transition.initial_states_.get_value()[None, :]
                # NOTE(review): for the last word ``j`` is -1, so the slice
                # below leaves out the final element of ``example`` --
                # confirm this is intended.
                for i, j in equizip([-1] + spc_inds, spc_inds + [-1]):
                    word = example[(i+1):j, None]
                    word_cost, states = compute_cost(
                        word, numpy.ones_like(word, dtype=floatX), state)
                    state = states[-1]

                    # exp(-cost) of every lexicon entry from the same state.
                    costs = numpy.exp(-compute_cost(
                        examples, all_masks,
                        numpy.tile(state, [examples.shape[1], 1]))[0])

                    # Advance the state over the separating space.
                    _, space_states = compute_cost(
                        single_space,
                        numpy.ones_like(single_space, dtype=floatX), state)
                    state = space_states[-1]

                    word_prob = numpy.exp(-word_cost)
                    total_word_cost += word_cost + numpy.log(numpy.sum(costs))
                    num_words += 1
                    print(word_prob)
                    print(numpy.sum(costs))
                    print("Average cost", total_word_cost / num_words)
                    print("PPL", numpy.exp(total_word_cost / num_words))

        print("Word-level perplexity")
        print(total_word_cost / num_words)
    else:
        # Raised explicitly so the check also fires under ``python -O``.
        raise AssertionError("Unknown mode: {!r}".format(mode))
コード例 #8
0
def train_extractive_qa(new_training_job, config, save_path, params,
                        fast_start, fuel_server, seed):
    """Build and run the training main loop of the extractive QA model.

    Parameters
    ----------
    new_training_job : bool
        When false, the state stored in ``training_state.tar`` under
        ``save_path`` is loaded before training starts.
    config : mapping
        Experiment configuration; read through string keys such as
        ``'batch_size'``, ``'dropout'`` or ``'n_batches'``.
    save_path : str
        Directory used for ``training_state.tar``,
        ``training_state_best.tar``, ``stream.pkl`` and the files written by
        the dumping extensions.
    params : str or None
        Optional path of a parameter archive used to initialise the model.
    fast_start : bool
        When true, the extensions configured with
        ``before_training=not fast_start`` (validation, prediction dump,
        checkpoint, retrieval statistics) are not triggered before training.
    fuel_server : bool
        When true, batches are read through a ``ServerDataStream`` and the
        ``StartFuelServer`` extension is triggered before training.
    seed : int or None
        When truthy, used as default seed of both Fuel and Blocks. A falsy
        value (including ``0``) leaves the defaults untouched.

    Raises
    ------
    ValueError
        If dropout is enabled and ``config['emb_dropout_type']`` is neither
        ``'same_mask'`` nor ``'regular'``.
    """
    if seed:
        fuel.config.default_seed = seed
        blocks.config.config.default_seed = seed

    root_path = os.path.join(save_path, 'training_state')
    extension = '.tar'
    tar_path = root_path + extension
    best_tar_path = root_path + '_best' + extension

    c = config
    data, qam = initialize_data_and_model(c)

    if theano.config.compute_test_value != 'off':
        # Attach a tiny real batch as test values of the input variables.
        test_value_data = next(
            data.get_stream('train', shuffle=True, batch_size=4,
                            max_length=5).get_epoch_iterator(as_dict=True))
        for var in qam.input_vars.values():
            var.tag.test_value = test_value_data[var.name]

    costs = qam.apply_with_default_vars()
    cost = rename(costs.mean(), 'mean_cost')

    cg = Model(cost)
    if params:
        logger.debug("Load parameters from {}".format(params))
        # The parameter archive is binary data, so open it in binary mode.
        with open(params, 'rb') as src:
            cg.set_parameter_values(load_parameters(src))

    # Variables monitored on both the training and the validation data.
    length = rename(qam.contexts.shape[1], 'length')
    batch_size = rename(qam.contexts.shape[0], 'batch_size')
    predicted_begins, = VariableFilter(name='predicted_begins')(cg)
    predicted_ends, = VariableFilter(name='predicted_ends')(cg)
    exact_match, = VariableFilter(name='exact_match')(cg)
    exact_match_ratio = rename(exact_match.mean(), 'exact_match_ratio')
    context_unk_ratio, = VariableFilter(name='context_unk_ratio')(cg)
    monitored_vars = [
        length, batch_size, cost, exact_match_ratio, context_unk_ratio
    ]
    if c['dict_path']:
        def_unk_ratio, = VariableFilter(name='def_unk_ratio')(cg)
        num_definitions = rename(qam.input_vars['defs'].shape[0],
                                 'num_definitions')
        max_definition_length = rename(qam.input_vars['defs'].shape[1],
                                       'max_definition_length')
        monitored_vars.extend(
            [def_unk_ratio, num_definitions, max_definition_length])
        if c['def_word_gating'] == 'self_attention':
            def_gates = VariableFilter(name='def_gates')(cg)
            # NOTE(review): the per-variable extrema are unpacked as
            # positional arguments, so the number of 'def_gates' variables
            # must match what ``tensor.minimum``/``tensor.maximum`` accept.
            def_gates_min = tensor.minimum(*[x.min() for x in def_gates])
            def_gates_max = tensor.maximum(*[x.max() for x in def_gates])
            monitored_vars.extend([
                rename(def_gates_min, 'def_gates_min'),
                rename(def_gates_max, 'def_gates_max')
            ])
    text_match_ratio = TextMatchRatio(
        data_path=os.path.join(fuel.config.data_path[0],
                               'squad/dev-v1.1.json'),
        requires=[
            predicted_begins, predicted_ends,
            tensor.ltensor3('contexts_text'),
            tensor.lmatrix('q_ids')
        ],
        name='text_match_ratio')

    parameters = cg.get_parameter_dict()
    # Materialise a list so that the value is a real list (not a dict view)
    # regardless of the Python version and of the filters below.
    trained_parameters = list(parameters.values())
    if c['embedding_path']:
        logger.debug("Exclude  word embeddings from the trained parameters")
        trained_parameters = [
            p for p in trained_parameters if not p == qam.embeddings_var()
        ]
    if c['train_only_def_part']:
        def_reading_parameters = qam.def_reading_parameters()
        trained_parameters = [
            p for p in trained_parameters if p in def_reading_parameters
        ]

    logger.info("Cost parameters" + "\n" + pprint.pformat(
        [
            " ".join((
                key,
                str(parameters[key].get_value().shape),
                'trained'
                if parameters[key] in trained_parameters else 'frozen'))
            for key in sorted(parameters.keys())
        ],
        width=120))

    # apply dropout to the training cost and to all the variables
    # that we monitor during training
    train_cost = cost
    train_monitored_vars = list(monitored_vars)
    if c['dropout']:
        regularized_cg = ComputationGraph([cost] + train_monitored_vars)
        # Dima: the dropout that I implemented first
        bidir_outputs, = VariableFilter(bricks=[Bidirectional],
                                        roles=[OUTPUT])(cg)
        readout_layers = VariableFilter(bricks=[Rectifier], roles=[OUTPUT])(cg)
        dropout_vars = [bidir_outputs] + readout_layers
        logger.debug("applying dropout to {}".format(", ".join(
            [v.name for v in dropout_vars])))
        regularized_cg = apply_dropout(regularized_cg, dropout_vars,
                                       c['dropout'])
        # a new dropout with exactly same mask at different steps
        emb_vars = VariableFilter(roles=[EMBEDDINGS])(regularized_cg)
        emb_dropout_mask = get_dropout_mask(emb_vars[0], c['emb_dropout'])
        if c['emb_dropout_type'] == 'same_mask':
            regularized_cg = apply_dropout2(regularized_cg,
                                            emb_vars,
                                            c['emb_dropout'],
                                            dropout_mask=emb_dropout_mask)
        elif c['emb_dropout_type'] == 'regular':
            regularized_cg = apply_dropout(regularized_cg, emb_vars,
                                           c['emb_dropout'])
        else:
            raise ValueError("unknown dropout type {}".format(
                c['emb_dropout_type']))
        # Output 0 is the cost, the rest mirrors ``monitored_vars``.
        train_cost = regularized_cg.outputs[0]
        train_monitored_vars = regularized_cg.outputs[1:]

    # Optional step clipping followed by Adam.
    rules = []
    if c['grad_clip_threshold']:
        rules.append(StepClipping(c['grad_clip_threshold']))
    rules.append(Adam(learning_rate=c['learning_rate'], beta1=c['momentum']))
    algorithm = GradientDescent(cost=train_cost,
                                parameters=trained_parameters,
                                step_rule=CompositeRule(rules))

    if c['grad_clip_threshold']:
        train_monitored_vars.append(algorithm.total_gradient_norm)
    if c['monitor_parameters']:
        train_monitored_vars.extend(parameter_stats(parameters, algorithm))

    training_stream = data.get_stream('train',
                                      batch_size=c['batch_size'],
                                      shuffle=True,
                                      max_length=c['max_length'])
    # Keep the local stream: ``StartFuelServer`` below receives it even when
    # training itself reads from the server stream.
    original_training_stream = training_stream
    if fuel_server:
        # the port will be configured by the StartFuelServer extension
        training_stream = ServerDataStream(
            sources=training_stream.sources,
            produces_examples=training_stream.produces_examples)

    extensions = [
        LoadNoUnpickling(tar_path, load_iteration_state=True,
                         load_log=True).set_conditions(
                             before_training=not new_training_job),
        StartFuelServer(original_training_stream,
                        os.path.join(save_path, 'stream.pkl'),
                        before_training=fuel_server),
        Timing(every_n_batches=c['mon_freq_train']),
        TrainingDataMonitoring(train_monitored_vars,
                               prefix="train",
                               every_n_batches=c['mon_freq_train']),
    ]
    validation = DataStreamMonitoring(
        [text_match_ratio] + monitored_vars,
        data.get_stream('dev',
                        batch_size=c['batch_size_valid'],
                        raw_text=True,
                        q_ids=True),
        prefix="dev").set_conditions(before_training=not fast_start,
                                     after_epoch=True)
    dump_predictions = DumpPredictions(save_path,
                                       text_match_ratio,
                                       before_training=not fast_start,
                                       after_epoch=True)
    # Both trackers look for the maximum of a validation record.
    track_the_best_exact = TrackTheBest(
        validation.record_name(exact_match_ratio),
        choose_best=max).set_conditions(before_training=True, after_epoch=True)
    track_the_best_text = TrackTheBest(
        validation.record_name(text_match_ratio),
        choose_best=max).set_conditions(before_training=True, after_epoch=True)
    extensions.extend([
        validation, dump_predictions, track_the_best_exact, track_the_best_text
    ])
    # We often use pretrained word embeddings and we don't want
    # to load and save them every time. To avoid that, we use
    # save_main_loop=False, we only save the trained parameters,
    # and we save the log and the iterations state separately
    # in the tar file.
    extensions.extend([
        # The extra condition passes ``best_tar_path`` to the checkpoint
        # whenever the log holds the record named by
        # ``track_the_best_text.notification_name``.
        Checkpoint(tar_path,
                   parameters=trained_parameters,
                   save_main_loop=False,
                   save_separately=['log', 'iteration_state'],
                   before_training=not fast_start,
                   every_n_epochs=c['save_freq_epochs'],
                   every_n_batches=c['save_freq_batches'],
                   after_training=not fast_start).add_condition(
                       ['after_batch', 'after_epoch'],
                       OnLogRecord(track_the_best_text.notification_name),
                       (best_tar_path, )),
        DumpTensorflowSummaries(save_path,
                                after_epoch=True,
                                every_n_batches=c['mon_freq_train'],
                                after_training=True),
        RetrievalPrintStats(retrieval=data._retrieval,
                            every_n_batches=c['mon_freq_train'],
                            before_training=not fast_start),
        Printing(after_epoch=True, every_n_batches=c['mon_freq_train']),
        FinishAfter(after_n_batches=c['n_batches'],
                    after_n_epochs=c['n_epochs']),
        # After ``annealing_start_epoch`` epochs: trigger the annealing
        # extension and load from ``best_tar_path``.
        Annealing(c['annealing_learning_rate'],
                  after_n_epochs=c['annealing_start_epoch']),
        LoadNoUnpickling(best_tar_path,
                         after_n_epochs=c['annealing_start_epoch'])
    ])

    main_loop = MainLoop(algorithm,
                         training_stream,
                         model=Model(cost),
                         extensions=extensions)
    main_loop.run()
コード例 #9
0
ファイル: main.py プロジェクト: DingKe/attention-lvcsr
def initialize_all(config, save_path, bokeh_name,
                   params, bokeh_server, bokeh, test_tag, use_load_ext,
                   load_log, fast_start):
    """Build the model, the training algorithm and the main loop extensions.

    Parameters
    ----------
    config : dict
        Experiment configuration. The sections read here are ``'data'``,
        ``'training'``, ``'monitoring'``, ``'net'`` and, optionally,
        ``'regularization'``.
    save_path : str
        Checkpoint path. Its root and extension are also used to derive
        the ``_best`` and ``_best_ll`` checkpoint paths.
    bokeh_name : str
        Name of the Bokeh plot; the basename of `save_path` is used when
        it is empty.
    params : str
        Path to a saved parameter archive. When truthy the values are
        loaded into the model, and the path is also handed to the
        ``Load`` / ``LoadLog`` extensions if those are requested.
    bokeh_server : str
        Passed to ``Plot`` as ``server_url``.
    bokeh : bool
        Whether to add the ``Plot`` extension.
    test_tag
        Forwarded to ``create_model``.
    use_load_ext : bool
        Add a ``Load`` extension (requires `params`).
    load_log : bool
        Add a ``LoadLog`` extension (requires `params`).
    fast_start : bool
        When true, the validation, the search and the checkpoint
        that normally run before the first epoch are skipped.

    Returns
    -------
    tuple
        ``(model, algorithm, data, extensions)``.

    """
    root_path, extension = os.path.splitext(save_path)

    data = Data(**config['data'])
    train_conf = config['training']
    recognizer = create_model(config, data, test_tag)

    # Separate attention_params to be handled differently
    # when regularization is applied
    attention = recognizer.generator.transition.attention
    attention_params = Selector(attention).get_parameters().values()

    logger.info(
        "Initialization schemes for all bricks.\n"
        "Works well only in my branch with __repr__ added to all them,\n"
        "there is an issue #463 in Blocks to do that properly.")

    def show_init_scheme(cur):
        # Recursively collect every ``*_init`` attribute of a brick and
        # of its children into a nested dictionary.
        result = dict()
        for attr in dir(cur):
            if attr.endswith('_init'):
                result[attr] = getattr(cur, attr)
        for child in cur.children:
            result[child.name] = show_init_scheme(child)
        return result
    logger.info(pprint.pformat(show_init_scheme(recognizer)))

    prediction, prediction_mask = add_exploration(recognizer, data, train_conf)

    #
    # Observables:
    #
    primary_observables = []  # monitored each batch
    secondary_observables = []  # monitored every 10 batches
    validation_observables = []  # monitored on the validation set

    cg = recognizer.get_cost_graph(
        batch=True, prediction=prediction, prediction_mask=prediction_mask)
    labels, = VariableFilter(
        applications=[recognizer.cost], name='labels')(cg)
    labels_mask, = VariableFilter(
        applications=[recognizer.cost], name='labels_mask')(cg)

    # The gain matrix is only present for some emitters, so it is
    # monitored only when the filter finds it.
    gain_matrix = VariableFilter(
        theano_name=RewardRegressionEmitter.GAIN_MATRIX)(cg)
    if len(gain_matrix):
        gain_matrix, = gain_matrix
        primary_observables.append(
            rename(gain_matrix.min(), 'min_gain'))
        primary_observables.append(
            rename(gain_matrix.max(), 'max_gain'))

    batch_cost = cg.outputs[0].sum()
    batch_size = rename(recognizer.labels.shape[1], "batch_size")
    # Assumes constant batch size. `aggregation.mean` is not used because
    # of Blocks #514.
    cost = batch_cost / batch_size
    cost.name = "sequence_total_cost"
    logger.info("Cost graph is built")

    # Fetch variables useful for debugging.
    # It is important not to use any aggregation schemes here,
    # as it's currently impossible to spread the effect of
    # regularization on their variables, see Blocks #514.
    cost_cg = ComputationGraph(cost)
    r = recognizer
    energies, = VariableFilter(
        applications=[r.generator.readout.readout], name="output_0")(
            cost_cg)
    bottom_output = VariableFilter(
        # We need name_regex instead of name because LookupTable calls
        # its output output_0
        applications=[r.bottom.apply], name_regex="output")(
            cost_cg)[-1]
    attended, = VariableFilter(
        applications=[r.generator.transition.apply], name="attended")(
            cost_cg)
    attended_mask, = VariableFilter(
        applications=[r.generator.transition.apply], name="attended_mask")(
            cost_cg)
    weights, = VariableFilter(
        applications=[r.generator.evaluate], name="weights")(
            cost_cg)
    max_recording_length = rename(bottom_output.shape[0],
                                  "max_recording_length")
    # To exclude subsampling related bugs
    max_attended_mask_length = rename(attended_mask.shape[0],
                                      "max_attended_mask_length")
    max_attended_length = rename(attended.shape[0],
                                 "max_attended_length")
    max_num_phonemes = rename(labels.shape[0],
                              "max_num_phonemes")
    min_energy = rename(energies.min(), "min_energy")
    max_energy = rename(energies.max(), "max_energy")
    mean_attended = rename(abs(attended).mean(),
                           "mean_attended")
    mean_bottom_output = rename(abs(bottom_output).mean(),
                                "mean_bottom_output")
    weights_penalty = rename(monotonicity_penalty(weights, labels_mask),
                             "weights_penalty")
    weights_entropy = rename(entropy(weights, labels_mask),
                             "weights_entropy")
    mask_density = rename(labels_mask.mean(),
                          "mask_density")
    # The order of the outputs matters: the code below addresses the
    # cost as outputs[0] and the weights penalty as outputs[1].
    cg = ComputationGraph([
        cost, weights_penalty, weights_entropy,
        min_energy, max_energy,
        mean_attended, mean_bottom_output,
        batch_size, max_num_phonemes,
        mask_density])
    # Regularization. It is applied explicitly to all variables
    # of interest, it could not be applied to the cost only as it
    # would not have effect on auxiliary variables, see Blocks #514.
    reg_config = config.get('regularization', dict())
    regularized_cg = cg
    if reg_config.get('dropout'):
        logger.info('apply dropout')
        regularized_cg = apply_dropout(cg, [bottom_output], 0.5)
    if reg_config.get('noise'):
        logger.info('apply noise')
        noise_subjects = [p for p in cg.parameters if p not in attention_params]
        # NOTE(review): noise is applied to `cg`, not to `regularized_cg`,
        # so when both 'dropout' and 'noise' are configured the graph
        # with dropout is replaced here. Confirm that this is intended.
        regularized_cg = apply_noise(cg, noise_subjects, reg_config['noise'])

    train_cost = regularized_cg.outputs[0]
    if reg_config.get("penalty_coof", .0) > 0:
        # big warning!!!
        # here we assume that:
        # regularized_weights_penalty = regularized_cg.outputs[1]
        train_cost = (train_cost +
                      reg_config.get("penalty_coof", .0) *
                      regularized_cg.outputs[1] / batch_size)
    if reg_config.get("decay", .0) > 0:
        train_cost = (train_cost + reg_config.get("decay", .0) *
                      l2_norm(VariableFilter(roles=[WEIGHT])(cg.parameters)) ** 2)

    train_cost = rename(train_cost, 'train_cost')

    gradients = None
    if reg_config.get('adaptive_noise'):
        logger.info('apply adaptive noise')
        if ((reg_config.get("penalty_coof", .0) > 0) or
                (reg_config.get("decay", .0) > 0)):
            logger.error('using  adaptive noise with alignment weight panalty '
                         'or weight decay is probably stupid')
        train_cost, regularized_cg, gradients, noise_brick = apply_adaptive_noise(
            cg, cg.outputs[0],
            variables=cg.parameters,
            num_examples=data.get_dataset('train').num_examples,
            parameters=Model(regularized_cg.outputs[0]).get_parameter_dict().values(),
            **reg_config.get('adaptive_noise')
        )
        train_cost.name = 'train_cost'
        adapt_noise_cg = ComputationGraph(train_cost)
        model_prior_mean = rename(
            VariableFilter(applications=[noise_brick.apply],
                           name='model_prior_mean')(adapt_noise_cg)[0],
            'model_prior_mean')
        model_cost = rename(
            VariableFilter(applications=[noise_brick.apply],
                           name='model_cost')(adapt_noise_cg)[0],
            'model_cost')
        model_prior_variance = rename(
            VariableFilter(applications=[noise_brick.apply],
                           name='model_prior_variance')(adapt_noise_cg)[0],
            'model_prior_variance')
        # New layout of the outputs: train cost, model cost, the previous
        # outputs of the regularized graph, prior mean, prior variance.
        regularized_cg = ComputationGraph(
            [train_cost, model_cost] +
            regularized_cg.outputs +
            [model_prior_mean, model_prior_variance])
        primary_observables += [
            regularized_cg.outputs[1],  # model cost
            regularized_cg.outputs[2],  # task cost
            regularized_cg.outputs[-2],  # model prior mean
            regularized_cg.outputs[-1]]  # model prior variance

    model = Model(train_cost)
    if params:
        logger.info("Load parameters from " + params)
        # please note: we cannot use recognizer.load_params
        # as it builds a new computation graph that does not have
        # shared variables added by adaptive weight noise
        #
        # The parameter archive is binary data, so it is opened in
        # binary mode.
        with open(params, 'rb') as src:
            param_values = load_parameters(src)
        model.set_parameter_values(param_values)

    parameters = model.get_parameter_dict()
    logger.info("Parameters:\n" +
                pprint.pformat(
                    [(key, parameters[key].get_value().shape) for key
                     in sorted(parameters.keys())],
                    width=120))

    # Define the training algorithm.
    clipping = StepClipping(train_conf['gradient_threshold'])
    clipping.threshold.name = "gradient_norm_threshold"
    rule_names = train_conf.get('rules', ['momentum'])
    core_rules = []
    if 'momentum' in rule_names:
        logger.info("Using scaling and momentum for training")
        core_rules.append(Momentum(train_conf['scale'], train_conf['momentum']))
    if 'adadelta' in rule_names:
        logger.info("Using AdaDelta for training")
        core_rules.append(AdaDelta(train_conf['decay_rate'], train_conf['epsilon']))
    max_norm_rules = []
    if reg_config.get('max_norm', False) > 0:
        logger.info("Apply MaxNorm")
        maxnorm_subjects = VariableFilter(roles=[WEIGHT])(cg.parameters)
        if reg_config.get('max_norm_exclude_lookup', False):
            maxnorm_subjects = [v for v in maxnorm_subjects
                                if not isinstance(get_brick(v), LookupTable)]
        logger.info("Parameters covered by MaxNorm:\n"
                    + pprint.pformat([name for name, p in parameters.items()
                                      if p in maxnorm_subjects]))
        logger.info("Parameters NOT covered by MaxNorm:\n"
                    + pprint.pformat([name for name, p in parameters.items()
                                      if p not in maxnorm_subjects]))
        max_norm_rules = [
            Restrict(VariableClipping(reg_config['max_norm'], axis=0),
                     maxnorm_subjects)]
    burn_in = []
    if train_conf.get('burn_in_steps', 0):
        burn_in.append(
            BurnIn(num_steps=train_conf['burn_in_steps']))
    algorithm = GradientDescent(
        cost=train_cost,
        parameters=parameters.values(),
        gradients=gradients,
        step_rule=CompositeRule(
            [clipping] + core_rules + max_norm_rules +
            # Parameters are not changed at all
            # when nans are encountered.
            [RemoveNotFinite(0.0)] + burn_in),
        on_unused_sources='warn')

    logger.debug("Scan Ops in the gradients")
    gradient_cg = ComputationGraph(algorithm.gradients.values())
    for op in ComputationGraph(gradient_cg).scans:
        logger.debug(op)

    # More variables for debugging: some of them can be added only
    # after the `algorithm` object is created.
    secondary_observables += list(regularized_cg.outputs)
    if 'train_cost' not in [v.name for v in secondary_observables]:
        secondary_observables += [train_cost]
    secondary_observables += [
        algorithm.total_step_norm, algorithm.total_gradient_norm,
        clipping.threshold]
    for name, param in parameters.items():
        # Per-parameter statistics: the norms of the value, of the
        # gradient and of the step, each divided by the square root of
        # the number of elements, and the step-to-gradient ratio.
        num_elements = numpy.product(param.get_value().shape)
        norm = param.norm(2) / num_elements ** 0.5
        grad_norm = algorithm.gradients[param].norm(2) / num_elements ** 0.5
        step_norm = algorithm.steps[param].norm(2) / num_elements ** 0.5
        stats = tensor.stack(norm, grad_norm, step_norm, step_norm / grad_norm)
        stats.name = name + '_stats'
        secondary_observables.append(stats)

    primary_observables += [
        train_cost,
        algorithm.total_gradient_norm,
        algorithm.total_step_norm, clipping.threshold,
        max_recording_length,
        max_attended_length, max_attended_mask_length]

    validation_observables += [
        rename(aggregation.mean(batch_cost, batch_size), cost.name),
        rename(aggregation.sum_(batch_size), 'num_utterances'),
        weights_entropy, weights_penalty]

    def attach_aggregation_schemes(variables):
        # Aggregation specification has to be factored out as a separate
        # function as it has to be applied at the very last stage
        # separately to training and validation observables.
        result = []
        for var in variables:
            if var.name == 'weights_penalty':
                result.append(rename(aggregation.mean(var, batch_size),
                                     'weights_penalty_per_recording'))
            elif var.name == 'weights_entropy':
                result.append(rename(aggregation.mean(var, labels_mask.sum()),
                                     'weights_entropy_per_label'))
            else:
                result.append(var)
        return result

    mon_conf = config['monitoring']

    # Build main loop.
    logger.info("Initialize extensions")
    extensions = []
    if use_load_ext and params:
        extensions.append(Load(params, load_iteration_state=True, load_log=True))
    if load_log and params:
        extensions.append(LoadLog(params))
    extensions += [
        Timing(after_batch=True),
        CGStatistics(),
        #CodeVersion(['lvsr']),
    ]
    extensions.append(TrainingDataMonitoring(
        primary_observables, after_batch=True))
    average_monitoring = TrainingDataMonitoring(
        attach_aggregation_schemes(secondary_observables),
        prefix="average", every_n_batches=10)
    extensions.append(average_monitoring)
    validation = DataStreamMonitoring(
        attach_aggregation_schemes(validation_observables),
        data.get_stream("valid", shuffle=False), prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=mon_conf['validate_every_epochs'],
            every_n_batches=mon_conf['validate_every_batches'],
            after_training=False)
    extensions.append(validation)
    per = PhonemeErrorRate(recognizer, data,
                           **config['monitoring']['search'])
    per_monitoring = DataStreamMonitoring(
        [per], data.get_stream("valid", batches=False, shuffle=False),
        prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            every_n_epochs=mon_conf['search_every_epochs'],
            every_n_batches=mon_conf['search_every_batches'],
            after_training=False)
    extensions.append(per_monitoring)
    track_the_best_per = TrackTheBest(
        per_monitoring.record_name(per)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    track_the_best_cost = TrackTheBest(
        validation.record_name(cost)).set_conditions(
            before_first_epoch=True, after_epoch=True)
    extensions += [track_the_best_cost, track_the_best_per]
    extensions.append(AdaptiveClipping(
        algorithm.total_gradient_norm.name,
        clipping, train_conf['gradient_threshold'],
        decay_rate=0.998, burnin_period=500))
    extensions += [
        SwitchOffLengthFilter(
            data.length_filter,
            after_n_batches=train_conf.get('stop_filtering')),
        FinishAfter(after_n_batches=train_conf.get('num_batches'),
                    after_n_epochs=train_conf.get('num_epochs'))
            .add_condition(["after_batch"], _gradient_norm_is_none),
    ]
    channels = [
        # Plot 1: training and validation costs
        [average_monitoring.record_name(train_cost),
         validation.record_name(cost)],
        # Plot 2: gradient norm,
        [average_monitoring.record_name(algorithm.total_gradient_norm),
         average_monitoring.record_name(clipping.threshold)],
        # Plot 3: phoneme error rate
        [per_monitoring.record_name(per)],
        # Plot 4: training and validation mean weight entropy
        [average_monitoring._record_name('weights_entropy_per_label'),
         validation._record_name('weights_entropy_per_label')],
        # Plot 5: training and validation monotonicity penalty
        [average_monitoring._record_name('weights_penalty_per_recording'),
         validation._record_name('weights_penalty_per_recording')]]
    if bokeh:
        extensions += [
            Plot(bokeh_name if bokeh_name
                 else os.path.basename(save_path),
                 channels,
                 every_n_batches=10,
                 server_url=bokeh_server),]
    # Besides the regular checkpoint at `save_path`, separate copies are
    # written whenever the best PER or the best validation cost is
    # reported in the log.
    extensions += [
        Checkpoint(save_path,
                   before_first_epoch=not fast_start, after_epoch=True,
                   every_n_batches=train_conf.get('save_every_n_batches'),
                   save_separately=["model", "log"],
                   use_cpickle=True)
        .add_condition(
            ['after_epoch'],
            OnLogRecord(track_the_best_per.notification_name),
            (root_path + "_best" + extension,))
        .add_condition(
            ['after_epoch'],
            OnLogRecord(track_the_best_cost.notification_name),
            (root_path + "_best_ll" + extension,)),
        ProgressBar()]
    extensions.append(EmbedIPython(use_main_loop_run_caller_env=True))
    if config['net']['criterion']['name'].startswith('mse'):
        extensions.append(
            LogInputsGains(
                labels, cg, recognizer.generator.readout.emitter, data))

    if train_conf.get('patience'):
        patience_conf = train_conf['patience']
        if not patience_conf.get('notification_names'):
            # setdefault will not work for empty list
            patience_conf['notification_names'] = [
                track_the_best_per.notification_name,
                track_the_best_cost.notification_name]
        extensions.append(Patience(**patience_conf))

    extensions.append(Printing(every_n_batches=1,
                               attribute_filter=PrintingFilterList()))

    return model, algorithm, data, extensions
コード例 #10
0
def train_model(new_training_job, config, save_path, params, fast_start,
                fuel_server, seed):
    """Build the cost, the algorithm and the extensions, then run training.

    Parameters
    ----------
    new_training_job : bool
        When false, the load extension runs before training and restores
        the state from the checkpoint in `save_path`.
    config : dict
        Experiment configuration; it is read through the alias ``c``.
    save_path : str
        Directory in which ``main_loop.tar``, ``training_state.tar``,
        ``stream.pkl`` and ``best_model.tar`` are placed.
    params : str
        Optional path to a parameter archive to initialize the model from.
    fast_start : bool
        When true, the validation before the first epoch and the
        checkpoints before and after training are skipped.
    fuel_server : bool
        When true, the training data is read through a
        ``ServerDataStream`` and ``StartFuelServer`` runs before training.
    seed : int
        When truthy, it is used as the default seed of Fuel and Blocks.

    """
    c = config
    if seed:
        fuel.config.default_seed = seed
        blocks.config.config.default_seed = seed

    data, model = initialize_data_and_model(config, train_phase=True)

    # full main loop can be saved...
    main_loop_path = os.path.join(save_path, 'main_loop.tar')
    # or only state (log + params) which can be useful not to pickle embeddings
    state_path = os.path.join(save_path, 'training_state.tar')
    stream_path = os.path.join(save_path, 'stream.pkl')
    best_tar_path = os.path.join(save_path, "best_model.tar")

    keys = tensor.lmatrix('keys')
    n_identical_keys = tensor.lvector('n_identical_keys')
    words = tensor.ltensor3('words')
    words_mask = tensor.matrix('words_mask')
    if theano.config.compute_test_value != 'off':
        # TODO (left unspecified in the original).
        # NOTE(review): test values are set for `words` and `words_mask`
        # only, not for `keys` and `n_identical_keys`.
        test_value_data = next(
            data.get_stream('train', batch_size=4,
                            max_length=5).get_epoch_iterator())
        words.tag.test_value = test_value_data[0]
        words_mask.tag.test_value = test_value_data[1]

    if use_keys(c) and use_n_identical_keys(c):
        costs = model.apply(words,
                            words_mask,
                            keys,
                            n_identical_keys,
                            train_phase=True)
    elif use_keys(c):
        costs = model.apply(words, words_mask, keys, train_phase=True)
    else:
        costs = model.apply(words, words_mask, train_phase=True)
    cost = rename(costs.mean(), 'mean_cost')

    cg = Model(cost)
    if params:
        logger.debug("Load parameters from {}".format(params))
        # The parameter archive is binary data, so it is opened in
        # binary mode.
        with open(params, 'rb') as src:
            cg.set_parameter_values(load_parameters(src))

    length = rename(words.shape[1], 'length')
    perplexity, = VariableFilter(name='perplexity')(cg)
    monitored_vars = [length, cost, perplexity]
    if c['proximity_coef']:
        proximity_term, = VariableFilter(name='proximity_term')(cg)
        monitored_vars.append(proximity_term)

    # A single-argument call prints the same text as the former print
    # statement and is valid in both Python 2 and Python 3.
    print("inputs of the model: {}".format(cg.inputs))

    parameters = cg.get_parameter_dict()
    trained_parameters = parameters.values()
    saved_parameters = parameters.values()
    if c['embedding_path']:
        # `to_freeze` used to stay unbound when neither of the two options
        # below was set, which raised a NameError in the filtering code.
        to_freeze = None
        if c['freeze_pretrained']:
            logger.debug(
                "Exclude pretrained encoder embeddings from the trained parameters"
            )
            to_freeze = 'main'
        elif c['provide_targets']:
            logger.debug(
                "Exclude pretrained targets from the trained parameters")
            to_freeze = 'target'
        if to_freeze is None:
            logger.warning(
                "Embeddings are given but nothing is configured to be "
                "frozen: all parameters are trained and saved")
        else:
            # Looked up once instead of once per parameter.
            frozen = model.get_def_embeddings_params(to_freeze)
            trained_parameters = [
                p for p in trained_parameters if not p == frozen
            ]
            saved_parameters = [
                p for p in saved_parameters if not p == frozen
            ]

    logger.info("Cost parameters" + "\n" + pprint.pformat([
        " ".join(
            (key, str(parameters[key].get_value().shape),
             'trained' if parameters[key] in trained_parameters else 'frozen'))
        for key in sorted(parameters.keys())
    ],
                                                          width=120))

    rules = []
    if c['grad_clip_threshold']:
        rules.append(StepClipping(c['grad_clip_threshold']))
    rules.append(Adam(learning_rate=c['learning_rate'], beta1=c['momentum']))
    algorithm = GradientDescent(cost=cost,
                                parameters=trained_parameters,
                                step_rule=CompositeRule(rules))

    train_monitored_vars = list(monitored_vars)
    if c['grad_clip_threshold']:
        train_monitored_vars.append(algorithm.total_gradient_norm)

    if c['monitor_parameters']:
        train_monitored_vars.extend(parameter_stats(parameters, algorithm))

    # We use a completely random seed on purpose. With Fuel server
    # it's currently not possible to restore the state of the training
    # stream. That's why it's probably better to just have it stateless.
    stream_seed = numpy.random.randint(0, 10000000) if fuel_server else None
    training_stream = data.get_stream(
        'train',
        batch_size=c['batch_size'],
        max_length=c['max_length'],
        seed=stream_seed,
        remove_keys=not use_keys(c),
        remove_n_identical_keys=not use_n_identical_keys(c))
    print("trainin_stream will contains sources: {}".format(
        training_stream.sources))

    original_training_stream = training_stream
    if fuel_server:
        # the port will be configured by the StartFuelServer extension
        training_stream = ServerDataStream(
            sources=training_stream.sources,
            produces_examples=training_stream.produces_examples)

    validate = c['mon_freq_valid'] > 0

    if validate:
        valid_stream = data.get_stream(
            'valid',
            batch_size=c['batch_size_valid'],
            max_length=c['max_length'],
            seed=stream_seed,
            remove_keys=not use_keys(c),
            remove_n_identical_keys=not use_n_identical_keys(c))
        validation = DataStreamMonitoring(
            monitored_vars, valid_stream,
            prefix="valid").set_conditions(before_first_epoch=not fast_start,
                                           on_resumption=True,
                                           every_n_batches=c['mon_freq_valid'])
        track_the_best = TrackTheBest(validation.record_name(cost),
                                      choose_best=min).set_conditions(
                                          on_resumption=True,
                                          after_epoch=True,
                                          every_n_batches=c['mon_freq_valid'])

    # don't save the entire main loop to avoid pickling everything
    if c['fast_checkpoint']:
        cp_path = state_path
        load = (LoadNoUnpickling(cp_path,
                                 load_iteration_state=True,
                                 load_log=True).set_conditions(
                                     before_training=not new_training_job))
        cp_args = {
            'save_main_loop': False,
            'save_separately': ['log', 'iteration_state'],
            'parameters': saved_parameters
        }

    else:
        cp_path = main_loop_path
        load = (Load(cp_path, load_iteration_state=True,
                     load_log=True).set_conditions(
                         before_training=not new_training_job))
        cp_args = {
            'save_separately': ['iteration_state'],
            'parameters': saved_parameters
        }

    checkpoint = Checkpoint(cp_path,
                            before_training=not fast_start,
                            every_n_batches=c['save_freq_batches'],
                            after_training=not fast_start,
                            **cp_args)

    # Intermediate checkpoints are requested by a positive frequency in
    # batches or in epochs; the flag is reused when the extension list
    # is assembled below.
    use_intermediate_cp = (c['checkpoint_every_n_batches'] > 0 or
                           c['checkpoint_every_n_epochs'] > 0)
    if use_intermediate_cp:
        intermediate_cp = IntermediateCheckpoint(
            cp_path,
            every_n_epochs=c['checkpoint_every_n_epochs'],
            every_n_batches=c['checkpoint_every_n_batches'],
            after_training=False,
            **cp_args)

    if validate:
        # Save a separate copy every time a new best validation cost is
        # reported in the log.
        checkpoint = checkpoint.add_condition(
            ['after_batch', 'after_epoch'],
            OnLogRecord(track_the_best.notification_name), (best_tar_path, ))

    extensions = [
        load,
        StartFuelServer(original_training_stream,
                        stream_path,
                        before_training=fuel_server),
        Timing(every_n_batches=c['mon_freq_train'])
    ]

    extensions.extend([
        TrainingDataMonitoring(train_monitored_vars,
                               prefix="train",
                               every_n_batches=c['mon_freq_train']),
    ])
    if validate:
        extensions.extend([validation, track_the_best])

    extensions.append(checkpoint)
    if use_intermediate_cp:
        extensions.append(intermediate_cp)
    extensions.extend(
        [Printing(on_resumption=True, every_n_batches=c['mon_freq_train'])])

    if validate and c['n_valid_early'] > 0:
        extensions.append(
            FinishIfNoImprovementAfter(track_the_best.notification_name,
                                       iterations=c['n_valid_early'] *
                                       c['mon_freq_valid'],
                                       every_n_batches=c['mon_freq_valid']))
    extensions.append(FinishAfter(after_n_epochs=c['n_epochs']))

    logger.info("monitored variables during training:" + "\n" +
                pprint.pformat(train_monitored_vars, width=120))
    logger.info("monitored variables during valid:" + "\n" +
                pprint.pformat(monitored_vars, width=120))

    main_loop = MainLoop(algorithm,
                         training_stream,
                         model=Model(cost),
                         extensions=extensions)

    main_loop.run()