def train_extractive_qa(new_training_job, config, save_path, params,
                        fast_start, fuel_server, seed):
    """Train the extractive QA model described by `config`.

    Checkpoints and logs go under `save_path`; `params` optionally points to
    a file with initial parameter values.
    """
    if seed:
        fuel.config.default_seed = seed
        blocks.config.config.default_seed = seed

    root_path = os.path.join(save_path, 'training_state')
    extension = '.tar'
    tar_path = root_path + extension
    best_tar_path = root_path + '_best' + extension

    c = config
    data, qam = initialize_data_and_model(c)

    if theano.config.compute_test_value != 'off':
        test_value_data = next(
            data.get_stream('train', shuffle=True, batch_size=4,
                            max_length=5).get_epoch_iterator(as_dict=True))
        for var in qam.input_vars.values():
            var.tag.test_value = test_value_data[var.name]

    costs = qam.apply_with_default_vars()
    cost = rename(costs.mean(), 'mean_cost')

    cg = Model(cost)
    if params:
        logger.debug("Load parameters from {}".format(params))
        with open(params) as src:
            cg.set_parameter_values(load_parameters(src))

    length = rename(qam.contexts.shape[1], 'length')
    batch_size = rename(qam.contexts.shape[0], 'batch_size')
    predicted_begins, = VariableFilter(name='predicted_begins')(cg)
    predicted_ends, = VariableFilter(name='predicted_ends')(cg)
    exact_match, = VariableFilter(name='exact_match')(cg)
    exact_match_ratio = rename(exact_match.mean(), 'exact_match_ratio')
    context_unk_ratio, = VariableFilter(name='context_unk_ratio')(cg)
    monitored_vars = [length, batch_size, cost, exact_match_ratio,
                      context_unk_ratio]
    if c['dict_path']:
        def_unk_ratio, = VariableFilter(name='def_unk_ratio')(cg)
        num_definitions = rename(qam.input_vars['defs'].shape[0],
                                 'num_definitions')
        max_definition_length = rename(qam.input_vars['defs'].shape[1],
                                       'max_definition_length')
        monitored_vars.extend(
            [def_unk_ratio, num_definitions, max_definition_length])
        if c['def_word_gating'] == 'self_attention':
            def_gates = VariableFilter(name='def_gates')(cg)
            def_gates_min = tensor.minimum(*[x.min() for x in def_gates])
            def_gates_max = tensor.maximum(*[x.max() for x in def_gates])
            monitored_vars.extend(
                [rename(def_gates_min, 'def_gates_min'),
                 rename(def_gates_max, 'def_gates_max')])
    text_match_ratio = TextMatchRatio(
        data_path=os.path.join(fuel.config.data_path[0],
                               'squad/dev-v1.1.json'),
        requires=[predicted_begins, predicted_ends,
                  tensor.ltensor3('contexts_text'),
                  tensor.lmatrix('q_ids')],
        name='text_match_ratio')

    parameters = cg.get_parameter_dict()
    trained_parameters = parameters.values()
    if c['embedding_path']:
        logger.debug("Exclude word embeddings from the trained parameters")
        trained_parameters = [
            p for p in trained_parameters
            if not p == qam.embeddings_var()]
    if c['train_only_def_part']:
        def_reading_parameters = qam.def_reading_parameters()
        trained_parameters = [
            p for p in trained_parameters if p in def_reading_parameters]

    logger.info("Cost parameters" + "\n" + pprint.pformat(
        [" ".join((key, str(parameters[key].get_value().shape),
                   'trained' if parameters[key] in trained_parameters
                   else 'frozen'))
         for key in sorted(parameters.keys())],
        width=120))

    # apply dropout to the training cost and to all the variables
    # that we monitor during training
    train_cost = cost
    train_monitored_vars = list(monitored_vars)
    if c['dropout']:
        regularized_cg = ComputationGraph([cost] + train_monitored_vars)
        # Dima: the dropout that I implemented first
        bidir_outputs, = VariableFilter(bricks=[Bidirectional],
                                        roles=[OUTPUT])(cg)
        readout_layers = VariableFilter(bricks=[Rectifier],
                                        roles=[OUTPUT])(cg)
        dropout_vars = [bidir_outputs] + readout_layers
        logger.debug("applying dropout to {}".format(
            ", ".join([v.name for v in dropout_vars])))
        regularized_cg = apply_dropout(regularized_cg, dropout_vars,
                                       c['dropout'])

        # a new dropout with exactly the same mask at different steps
        # (an illustrative standalone sketch follows this function)
        emb_vars = VariableFilter(roles=[EMBEDDINGS])(regularized_cg)
        emb_dropout_mask = get_dropout_mask(emb_vars[0], c['emb_dropout'])
        if c['emb_dropout_type'] == 'same_mask':
            regularized_cg = apply_dropout2(regularized_cg, emb_vars,
                                            c['emb_dropout'],
                                            dropout_mask=emb_dropout_mask)
        elif c['emb_dropout_type'] == 'regular':
            regularized_cg = apply_dropout(regularized_cg, emb_vars,
                                           c['emb_dropout'])
        else:
            raise ValueError("unknown dropout type {}".format(
                c['emb_dropout_type']))
        train_cost = regularized_cg.outputs[0]
        train_monitored_vars = regularized_cg.outputs[1:]

    rules = []
    if c['grad_clip_threshold']:
        rules.append(StepClipping(c['grad_clip_threshold']))
    rules.append(Adam(learning_rate=c['learning_rate'], beta1=c['momentum']))
    algorithm = GradientDescent(cost=train_cost,
                                parameters=trained_parameters,
                                step_rule=CompositeRule(rules))

    if c['grad_clip_threshold']:
        train_monitored_vars.append(algorithm.total_gradient_norm)
    if c['monitor_parameters']:
        train_monitored_vars.extend(parameter_stats(parameters, algorithm))

    training_stream = data.get_stream('train',
                                      batch_size=c['batch_size'],
                                      shuffle=True,
                                      max_length=c['max_length'])
    original_training_stream = training_stream
    if fuel_server:
        # the port will be configured by the StartFuelServer extension
        training_stream = ServerDataStream(
            sources=training_stream.sources,
            produces_examples=training_stream.produces_examples)

    extensions = [
        LoadNoUnpickling(tar_path,
                         load_iteration_state=True,
                         load_log=True).set_conditions(
                             before_training=not new_training_job),
        StartFuelServer(original_training_stream,
                        os.path.join(save_path, 'stream.pkl'),
                        before_training=fuel_server),
        Timing(every_n_batches=c['mon_freq_train']),
        TrainingDataMonitoring(train_monitored_vars,
                               prefix="train",
                               every_n_batches=c['mon_freq_train']),
    ]
    validation = DataStreamMonitoring(
        [text_match_ratio] + monitored_vars,
        data.get_stream('dev',
                        batch_size=c['batch_size_valid'],
                        raw_text=True,
                        q_ids=True),
        prefix="dev").set_conditions(before_training=not fast_start,
                                     after_epoch=True)
    dump_predictions = DumpPredictions(save_path, text_match_ratio,
                                       before_training=not fast_start,
                                       after_epoch=True)
    track_the_best_exact = TrackTheBest(
        validation.record_name(exact_match_ratio),
        choose_best=max).set_conditions(before_training=True,
                                        after_epoch=True)
    track_the_best_text = TrackTheBest(
        validation.record_name(text_match_ratio),
        choose_best=max).set_conditions(before_training=True,
                                        after_epoch=True)
    extensions.extend([validation, dump_predictions, track_the_best_exact,
                       track_the_best_text])
    # We often use pretrained word embeddings and we don't want
    # to load and save them every time. To avoid that, we use
    # save_main_loop=False, we only save the trained parameters,
    # and we save the log and the iteration state separately
    # in the tar file.
    extensions.extend([
        Checkpoint(tar_path,
                   parameters=trained_parameters,
                   save_main_loop=False,
                   save_separately=['log', 'iteration_state'],
                   before_training=not fast_start,
                   every_n_epochs=c['save_freq_epochs'],
                   every_n_batches=c['save_freq_batches'],
                   after_training=not fast_start).add_condition(
                       ['after_batch', 'after_epoch'],
                       OnLogRecord(track_the_best_text.notification_name),
                       (best_tar_path, )),
        DumpTensorflowSummaries(save_path,
                                after_epoch=True,
                                every_n_batches=c['mon_freq_train'],
                                after_training=True),
        RetrievalPrintStats(retrieval=data._retrieval,
                            every_n_batches=c['mon_freq_train'],
                            before_training=not fast_start),
        Printing(after_epoch=True, every_n_batches=c['mon_freq_train']),
        FinishAfter(after_n_batches=c['n_batches'],
                    after_n_epochs=c['n_epochs']),
        Annealing(c['annealing_learning_rate'],
                  after_n_epochs=c['annealing_start_epoch']),
        LoadNoUnpickling(best_tar_path,
                         after_n_epochs=c['annealing_start_epoch'])
    ])

    main_loop = MainLoop(algorithm,
                         training_stream,
                         model=Model(cost),
                         extensions=extensions)
    main_loop.run()
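

# Note: the 'same_mask' embedding dropout used in train_extractive_qa reuses
# one dropout mask at every time step (via get_dropout_mask / apply_dropout2).
# The helper below is only an illustration of that idea in plain Theano; its
# name and signature are not part of this codebase.
def _shared_mask_dropout_sketch(embeddings, drop_prob, seed=1):
    """Illustrative only: apply the same dropout mask at every time step.

    `embeddings` is a (batch, time, dim) tensor variable; returns the
    dropped-out embeddings with inverted scaling.
    """
    import theano
    import theano.tensor as tensor
    from theano.sandbox.rng_mrg import MRG_RandomStreams

    rng = MRG_RandomStreams(seed=seed)
    keep_prob = 1.0 - drop_prob
    # One mask per (example, feature) pair: shape (batch, 1, dim).
    mask = rng.binomial(size=(embeddings.shape[0], 1, embeddings.shape[2]),
                        p=keep_prob, dtype=theano.config.floatX)
    # Make the singleton time axis broadcastable so the very same mask is
    # reused at every step, instead of being resampled per step.
    mask = tensor.addbroadcast(mask, 1)
    return embeddings * mask / keep_prob
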
def train_language_model(new_training_job, config, save_path, params,
                         fast_start, fuel_server, seed):
    """Train the language model described by `config`.

    Checkpoints and logs go under `save_path`; `params` optionally points to
    a file with initial parameter values.
    """
    c = config
    if seed:
        fuel.config.default_seed = seed
        blocks.config.config.default_seed = seed

    data, lm, retrieval = initialize_data_and_model(config)

    # full main loop can be saved...
    main_loop_path = os.path.join(save_path, 'main_loop.tar')
    # or only state (log + params), which is useful not to pickle embeddings
    state_path = os.path.join(save_path, 'training_state.tar')
    stream_path = os.path.join(save_path, 'stream.pkl')
    best_tar_path = os.path.join(save_path, "best_model.tar")

    words = tensor.ltensor3('words')
    words_mask = tensor.matrix('words_mask')
    if theano.config.compute_test_value != 'off':
        test_value_data = next(
            data.get_stream('train', batch_size=4,
                            max_length=5).get_epoch_iterator())
        words.tag.test_value = test_value_data[0]
        words_mask.tag.test_value = test_value_data[1]

    costs, updates = lm.apply(words, words_mask)
    cost = rename(costs.mean(), 'mean_cost')

    cg = Model(cost)
    if params:
        logger.debug("Load parameters from {}".format(params))
        with open(params) as src:
            cg.set_parameter_values(load_parameters(src))

    length = rename(words.shape[1], 'length')
    perplexity, = VariableFilter(name='perplexity')(cg)
    perplexities = VariableFilter(name_regex='perplexity.*')(cg)
    monitored_vars = [length, cost] + perplexities
    if c['dict_path']:
        num_definitions, = VariableFilter(name='num_definitions')(cg)
        monitored_vars.extend([num_definitions])

    parameters = cg.get_parameter_dict()
    trained_parameters = parameters.values()
    saved_parameters = parameters.values()
    if c['embedding_path']:
        logger.debug("Exclude word embeddings from the trained parameters")
        trained_parameters = [
            p for p in trained_parameters
            if not p == lm.get_def_embeddings_params()]
        saved_parameters = [
            p for p in saved_parameters
            if not p == lm.get_def_embeddings_params()]

    if c['cache_size'] != 0:
        logger.debug("Enable fake recursivity for looking up embeddings")
        trained_parameters = [
            p for p in trained_parameters
            if not p == lm.get_cache_params()]

    logger.info("Cost parameters" + "\n" + pprint.pformat(
        [" ".join((key, str(parameters[key].get_value().shape),
                   'trained' if parameters[key] in trained_parameters
                   else 'frozen'))
         for key in sorted(parameters.keys())],
        width=120))

    rules = []
    if c['grad_clip_threshold']:
        rules.append(StepClipping(c['grad_clip_threshold']))
    rules.append(Adam(learning_rate=c['learning_rate'], beta1=c['momentum']))
    algorithm = GradientDescent(cost=cost,
                                parameters=trained_parameters,
                                step_rule=CompositeRule(rules))

    if c['cache_size'] != 0:
        algorithm.add_updates(updates)

    train_monitored_vars = list(monitored_vars)
    if c['grad_clip_threshold']:
        train_monitored_vars.append(algorithm.total_gradient_norm)

    word_emb_RMS, = VariableFilter(name='word_emb_RMS')(cg)
    main_rnn_in_RMS, = VariableFilter(name='main_rnn_in_RMS')(cg)
    train_monitored_vars.extend([word_emb_RMS, main_rnn_in_RMS])

    if c['monitor_parameters']:
        train_monitored_vars.extend(parameter_stats(parameters, algorithm))

    # We use a completely random seed on purpose. With the Fuel server
    # it's currently not possible to restore the state of the training
    # stream, so it's probably better to just have it stateless.
    stream_seed = numpy.random.randint(0, 10000000) if fuel_server else None
    training_stream = data.get_stream('train',
                                      batch_size=c['batch_size'],
                                      max_length=c['max_length'],
                                      seed=stream_seed)
    valid_stream = data.get_stream('valid',
                                   batch_size=c['batch_size_valid'],
                                   max_length=c['max_length'],
                                   seed=stream_seed)
    original_training_stream = training_stream
    if fuel_server:
        # the port will be configured by the StartFuelServer extension
        # (see the note after this function)
        training_stream = ServerDataStream(
            sources=training_stream.sources,
            produces_examples=training_stream.produces_examples)

    validation = DataStreamMonitoring(
        monitored_vars, valid_stream,
        prefix="valid").set_conditions(
            before_first_epoch=not fast_start,
            on_resumption=True,
            every_n_batches=c['mon_freq_valid'])
    track_the_best = TrackTheBest(
        validation.record_name(perplexity),
        choose_best=min).set_conditions(
            on_resumption=True,
            after_epoch=True,
            every_n_batches=c['mon_freq_valid'])

    # don't save the entire main loop, to avoid pickling everything
    if c['fast_checkpoint']:
        load = (LoadNoUnpickling(state_path,
                                 load_iteration_state=True,
                                 load_log=True)
                .set_conditions(before_training=not new_training_job))
        cp_args = {
            'save_main_loop': False,
            'save_separately': ['log', 'iteration_state'],
            'parameters': saved_parameters
        }
        checkpoint = Checkpoint(state_path,
                                before_training=not fast_start,
                                every_n_batches=c['save_freq_batches'],
                                after_training=not fast_start,
                                **cp_args)
        if c['checkpoint_every_n_batches']:
            intermediate_cp = IntermediateCheckpoint(
                state_path,
                every_n_batches=c['checkpoint_every_n_batches'],
                after_training=False,
                **cp_args)
    else:
        load = (Load(main_loop_path,
                     load_iteration_state=True,
                     load_log=True)
                .set_conditions(before_training=not new_training_job))
        cp_args = {
            'save_separately': ['iteration_state'],
            'parameters': saved_parameters
        }
        checkpoint = Checkpoint(main_loop_path,
                                before_training=not fast_start,
                                every_n_batches=c['save_freq_batches'],
                                after_training=not fast_start,
                                **cp_args)
        if c['checkpoint_every_n_batches']:
            intermediate_cp = IntermediateCheckpoint(
                main_loop_path,
                every_n_batches=c['checkpoint_every_n_batches'],
                after_training=False,
                **cp_args)

    checkpoint = checkpoint.add_condition(
        ['after_batch', 'after_epoch'],
        OnLogRecord(track_the_best.notification_name),
        (best_tar_path, ))

    extensions = [
        load,
        StartFuelServer(original_training_stream,
                        stream_path,
                        before_training=fuel_server),
        Timing(every_n_batches=c['mon_freq_train'])
    ]
    if retrieval:
        extensions.append(
            RetrievalPrintStats(retrieval=retrieval,
                                every_n_batches=c['mon_freq_train'],
                                before_training=not fast_start))
    extensions.extend([
        TrainingDataMonitoring(train_monitored_vars,
                               prefix="train",
                               every_n_batches=c['mon_freq_train']),
        validation,
        track_the_best,
        checkpoint
    ])
    if c['checkpoint_every_n_batches']:
        extensions.append(intermediate_cp)
    extensions.extend([
        DumpTensorflowSummaries(save_path,
                                every_n_batches=c['mon_freq_train'],
                                after_training=True),
        Printing(on_resumption=True,
                 every_n_batches=c['mon_freq_train']),
        FinishIfNoImprovementAfter(track_the_best.notification_name,
                                   iterations=50 * c['mon_freq_valid'],
                                   every_n_batches=c['mon_freq_valid']),
        FinishAfter(after_n_batches=c['n_batches'])
    ])

    logger.info("monitored variables during training:" + "\n" +
                pprint.pformat(train_monitored_vars, width=120))
    logger.info("monitored variables during valid:" + "\n" +
                pprint.pformat(monitored_vars, width=120))

    main_loop = MainLoop(algorithm,
                         training_stream,
                         model=Model(cost),
                         extensions=extensions)
    main_loop.run()
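

# Note: when `fuel_server` is set, the training process does not iterate the
# data stream itself; StartFuelServer launches a producer process and the
# ServerDataStream above pulls ready batches from it over a socket. The
# snippet below is only an illustration of that split (the port shown is a
# placeholder; the extension configures the real one):
#
#     # producer process
#     from fuel.server import start_server
#     start_server(original_training_stream, port=5557)
#
#     # training process
#     from fuel.streams import ServerDataStream
#     training_stream = ServerDataStream(
#         sources=original_training_stream.sources,
#         produces_examples=original_training_stream.produces_examples,
#         port=5557)
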
def train_model(new_training_job, config, save_path, params, fast_start,
                fuel_server, seed):
    """Train the model described by `config`.

    Checkpoints and logs go under `save_path`; `params` optionally points to
    a file with initial parameter values.
    """
    c = config
    if seed:
        fuel.config.default_seed = seed
        blocks.config.config.default_seed = seed

    data, model = initialize_data_and_model(config, train_phase=True)

    # full main loop can be saved...
    main_loop_path = os.path.join(save_path, 'main_loop.tar')
    # or only state (log + params), which is useful not to pickle embeddings
    state_path = os.path.join(save_path, 'training_state.tar')
    stream_path = os.path.join(save_path, 'stream.pkl')
    best_tar_path = os.path.join(save_path, "best_model.tar")

    keys = tensor.lmatrix('keys')
    n_identical_keys = tensor.lvector('n_identical_keys')
    words = tensor.ltensor3('words')
    words_mask = tensor.matrix('words_mask')
    if theano.config.compute_test_value != 'off':
        #TODO
        test_value_data = next(
            data.get_stream('train', batch_size=4,
                            max_length=5).get_epoch_iterator())
        words.tag.test_value = test_value_data[0]
        words_mask.tag.test_value = test_value_data[1]

    if use_keys(c) and use_n_identical_keys(c):
        costs = model.apply(words, words_mask, keys, n_identical_keys,
                            train_phase=True)
    elif use_keys(c):
        costs = model.apply(words, words_mask, keys, train_phase=True)
    else:
        costs = model.apply(words, words_mask, train_phase=True)
    cost = rename(costs.mean(), 'mean_cost')

    cg = Model(cost)
    if params:
        logger.debug("Load parameters from {}".format(params))
        with open(params) as src:
            cg.set_parameter_values(load_parameters(src))

    length = rename(words.shape[1], 'length')
    perplexity, = VariableFilter(name='perplexity')(cg)
    monitored_vars = [length, cost, perplexity]
    if c['proximity_coef']:
        proximity_term, = VariableFilter(name='proximity_term')(cg)
        monitored_vars.append(proximity_term)

    print "inputs of the model:", cg.inputs

    parameters = cg.get_parameter_dict()
    trained_parameters = parameters.values()
    saved_parameters = parameters.values()
    if c['embedding_path']:
        if c['freeze_pretrained']:
            logger.debug("Exclude pretrained encoder embeddings from the "
                         "trained parameters")
            to_freeze = 'main'
        elif c['provide_targets']:
            logger.debug("Exclude pretrained targets from the trained "
                         "parameters")
            to_freeze = 'target'
        else:
            # to_freeze must be set by one of the branches above
            raise ValueError("embedding_path is set, but neither "
                             "freeze_pretrained nor provide_targets is")
        trained_parameters = [
            p for p in trained_parameters
            if not p == model.get_def_embeddings_params(to_freeze)]
        saved_parameters = [
            p for p in saved_parameters
            if not p == model.get_def_embeddings_params(to_freeze)]

    logger.info("Cost parameters" + "\n" + pprint.pformat(
        [" ".join((key, str(parameters[key].get_value().shape),
                   'trained' if parameters[key] in trained_parameters
                   else 'frozen'))
         for key in sorted(parameters.keys())],
        width=120))

    rules = []
    if c['grad_clip_threshold']:
        rules.append(StepClipping(c['grad_clip_threshold']))
    rules.append(Adam(learning_rate=c['learning_rate'], beta1=c['momentum']))
    algorithm = GradientDescent(cost=cost,
                                parameters=trained_parameters,
                                step_rule=CompositeRule(rules))

    train_monitored_vars = list(monitored_vars)
    if c['grad_clip_threshold']:
        train_monitored_vars.append(algorithm.total_gradient_norm)

    if c['monitor_parameters']:
        train_monitored_vars.extend(parameter_stats(parameters, algorithm))

    # We use a completely random seed on purpose. With the Fuel server
    # it's currently not possible to restore the state of the training
    # stream, so it's probably better to just have it stateless.
    stream_seed = numpy.random.randint(0, 10000000) if fuel_server else None
    training_stream = data.get_stream(
        'train',
        batch_size=c['batch_size'],
        max_length=c['max_length'],
        seed=stream_seed,
        remove_keys=not use_keys(c),
        remove_n_identical_keys=not use_n_identical_keys(c))
    print "training_stream will contain sources:", training_stream.sources

    original_training_stream = training_stream
    if fuel_server:
        # the port will be configured by the StartFuelServer extension
        training_stream = ServerDataStream(
            sources=training_stream.sources,
            produces_examples=training_stream.produces_examples)

    validate = c['mon_freq_valid'] > 0
    if validate:
        valid_stream = data.get_stream(
            'valid',
            batch_size=c['batch_size_valid'],
            max_length=c['max_length'],
            seed=stream_seed,
            remove_keys=not use_keys(c),
            remove_n_identical_keys=not use_n_identical_keys(c))
        validation = DataStreamMonitoring(
            monitored_vars, valid_stream,
            prefix="valid").set_conditions(
                before_first_epoch=not fast_start,
                on_resumption=True,
                every_n_batches=c['mon_freq_valid'])
        track_the_best = TrackTheBest(
            validation.record_name(cost),
            choose_best=min).set_conditions(
                on_resumption=True,
                after_epoch=True,
                every_n_batches=c['mon_freq_valid'])

    # don't save the entire main loop, to avoid pickling everything
    if c['fast_checkpoint']:
        cp_path = state_path
        load = (LoadNoUnpickling(cp_path,
                                 load_iteration_state=True,
                                 load_log=True)
                .set_conditions(before_training=not new_training_job))
        cp_args = {
            'save_main_loop': False,
            'save_separately': ['log', 'iteration_state'],
            'parameters': saved_parameters
        }
    else:
        cp_path = main_loop_path
        load = (Load(cp_path,
                     load_iteration_state=True,
                     load_log=True)
                .set_conditions(before_training=not new_training_job))
        cp_args = {
            'save_separately': ['iteration_state'],
            'parameters': saved_parameters
        }

    checkpoint = Checkpoint(cp_path,
                            before_training=not fast_start,
                            every_n_batches=c['save_freq_batches'],
                            after_training=not fast_start,
                            **cp_args)

    if (c['checkpoint_every_n_batches'] > 0
            or c['checkpoint_every_n_epochs'] > 0):
        intermediate_cp = IntermediateCheckpoint(
            cp_path,
            every_n_epochs=c['checkpoint_every_n_epochs'],
            every_n_batches=c['checkpoint_every_n_batches'],
            after_training=False,
            **cp_args)

    if validate:
        checkpoint = checkpoint.add_condition(
            ['after_batch', 'after_epoch'],
            OnLogRecord(track_the_best.notification_name),
            (best_tar_path, ))

    extensions = [
        load,
        StartFuelServer(original_training_stream,
                        stream_path,
                        before_training=fuel_server),
        Timing(every_n_batches=c['mon_freq_train'])
    ]

    extensions.extend([
        TrainingDataMonitoring(train_monitored_vars,
                               prefix="train",
                               every_n_batches=c['mon_freq_train']),
    ])
    if validate:
        extensions.extend([validation, track_the_best])
    extensions.append(checkpoint)
    if (c['checkpoint_every_n_batches'] > 0
            or c['checkpoint_every_n_epochs'] > 0):
        extensions.append(intermediate_cp)
    extensions.extend(
        [Printing(on_resumption=True, every_n_batches=c['mon_freq_train'])])
    if validate and c['n_valid_early'] > 0:
        extensions.append(
            FinishIfNoImprovementAfter(track_the_best.notification_name,
                                       iterations=c['n_valid_early'] *
                                       c['mon_freq_valid'],
                                       every_n_batches=c['mon_freq_valid']))
    extensions.append(FinishAfter(after_n_epochs=c['n_epochs']))

    logger.info("monitored variables during training:" + "\n" +
                pprint.pformat(train_monitored_vars, width=120))
    logger.info("monitored variables during valid:" + "\n" +
                pprint.pformat(monitored_vars, width=120))

    main_loop = MainLoop(algorithm,
                         training_stream,
                         model=Model(cost),
                         extensions=extensions)
    main_loop.run()