import os

import torch

# Project-internal helpers referenced below (Logger, atis_data, the
# evaluate_*_sample and train_epoch_with_* functions, Metrics, the
# *_EVAL_METRICS constants, and send_slack_message) are assumed to be
# imported elsewhere in this module.


def train(model, data, params):
    """ Trains a model.

    Inputs:
        model (ATISModel): The model to train.
        data (ATISData): The data that is used to train.
        params (namespace): Training parameters.

    Returns:
        The path of the last saved checkpoint, or None if no checkpoint
        was ever saved.
    """
    # Get the training batches.
    log = Logger(os.path.join(params.logdir, params.logfile), "w")
    num_train_original = atis_data.num_utterances(data.train_data)
    log.put("Original number of training utterances:\t"
            + str(num_train_original))

    eval_fn = evaluate_utterance_sample
    trainbatch_fn = data.get_utterance_batches
    trainsample_fn = data.get_random_utterances
    validsample_fn = data.get_all_utterances
    batch_size = params.batch_size
    if params.interaction_level:
        batch_size = 1
        eval_fn = evaluate_interaction_sample
        trainbatch_fn = data.get_interaction_batches
        trainsample_fn = data.get_random_interactions
        validsample_fn = data.get_all_interactions

    maximum_output_length = params.train_maximum_sql_length
    train_batches = trainbatch_fn(batch_size,
                                  max_output_length=maximum_output_length,
                                  randomize=not params.deterministic)

    if params.num_train >= 0:
        train_batches = train_batches[:params.num_train]

    training_sample = trainsample_fn(params.train_evaluation_size,
                                     max_output_length=maximum_output_length)
    valid_examples = validsample_fn(data.valid_data,
                                    max_output_length=maximum_output_length)

    num_train_examples = sum([len(batch) for batch in train_batches])
    num_steps_per_epoch = len(train_batches)

    log.put("Actual number of used training examples:\t"
            + str(num_train_examples))
    log.put("(Shortened by output limit of "
            + str(maximum_output_length) + ")")
    log.put("Number of steps per epoch:\t" + str(num_steps_per_epoch))
    log.put("Batch size:\t" + str(batch_size))

    print("Kept " + str(num_train_examples) + "/"
          + str(num_train_original) + " examples")
    print("Batch size of " + str(batch_size) + " gives "
          + str(num_steps_per_epoch) + " steps per epoch")

    # Keeping track of things during training.
    epochs = 0
    patience = params.initial_patience
    learning_rate_coefficient = 1.
    previous_epoch_loss = float('inf')
    maximum_string_accuracy = 0.
    countdown = int(patience)
    last_save_file = None  # Set once the first checkpoint is saved.

    if params.scheduler:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            model.trainer, mode='min')

    keep_training = True
    while keep_training:
        log.put("Epoch:\t" + str(epochs))

        model.set_dropout(params.dropout_amount)
        if not params.scheduler:
            model.set_learning_rate(
                learning_rate_coefficient * params.initial_learning_rate)

        # Run a training step.
        if params.interaction_level:
            epoch_loss = train_epoch_with_interactions(
                train_batches,
                params,
                model,
                randomize=not params.deterministic)
        else:
            epoch_loss = train_epoch_with_utterances(
                train_batches,
                model,
                randomize=not params.deterministic)

        log.put("train epoch loss:\t" + str(epoch_loss))

        model.set_dropout(0.)

        # Run an evaluation step on a sample of the training data.
        train_eval_results = eval_fn(training_sample,
                                     model,
                                     params.train_maximum_sql_length,
                                     name=os.path.join(params.logdir,
                                                       "train-eval"),
                                     write_results=True,
                                     gold_forcing=True,
                                     metrics=TRAIN_EVAL_METRICS)[0]

        for name, value in train_eval_results.items():
            log.put("train final gold-passing " + name.name + ":\t"
                    + "%.2f" % value)

        # Run an evaluation step on the validation set.
        valid_eval_results = eval_fn(valid_examples,
                                     model,
                                     params.eval_maximum_sql_length,
                                     name=os.path.join(params.logdir,
                                                       "valid-eval"),
                                     write_results=True,
                                     gold_forcing=True,
                                     metrics=VALID_EVAL_METRICS)[0]

        for name, value in valid_eval_results.items():
            log.put("valid gold-passing " + name.name + ":\t"
                    + "%.2f" % value)

        valid_loss = valid_eval_results[Metrics.LOSS]
        valid_token_accuracy = valid_eval_results[Metrics.TOKEN_ACCURACY]
        string_accuracy = valid_eval_results[Metrics.STRING_ACCURACY]

        if params.scheduler:
            scheduler.step(valid_loss)
        elif valid_loss > previous_epoch_loss:
            # Manual decay: shrink the learning rate when validation loss
            # fails to improve.
            learning_rate_coefficient *= params.learning_rate_ratio
            log.put("learning rate coefficient:\t"
                    + str(learning_rate_coefficient))

        previous_epoch_loss = valid_loss

        if string_accuracy > maximum_string_accuracy:
            maximum_string_accuracy = string_accuracy
            patience = patience * params.patience_ratio
            countdown = int(patience)
            last_save_file = os.path.join(params.logdir,
                                          "save_" + str(epochs))
            model.save(last_save_file)

            log.put("maximum string accuracy:\t"
                    + str(maximum_string_accuracy))
            log.put("patience:\t" + str(patience))
            log.put("save file:\t" + str(last_save_file))

        if countdown <= 0:
            keep_training = False

        countdown -= 1
        log.put("countdown:\t" + str(countdown))
        log.put("")

        epochs += 1

    log.put("Finished training!")
    log.close()

    return last_save_file
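

# A minimal sketch of how train() above can be driven. The attribute names
# match exactly what train() reads from `params`; the default values are
# illustrative assumptions, not the project's tuned settings.
def _example_params():
    """Builds a namespace carrying every field train() reads (sketch only)."""
    import argparse

    return argparse.Namespace(
        logdir="logs",
        logfile="train.log",
        batch_size=16,
        interaction_level=False,
        train_maximum_sql_length=200,
        eval_maximum_sql_length=400,
        deterministic=False,
        num_train=-1,  # A negative value keeps every training batch.
        train_evaluation_size=100,
        initial_patience=10.,
        patience_ratio=1.01,
        initial_learning_rate=0.001,
        learning_rate_ratio=0.8,
        dropout_amount=0.5,
        scheduler=False,
    )


# Usage, assuming `model` and `data` are constructed elsewhere:
#     last_save = train(model, data, _example_params())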


def train(model, data, params, last_save_file=None):
    """ Trains a model, optionally resuming from a saved checkpoint.

    Inputs:
        model (ATISModel): The model to train.
        data (ATISData): The data that is used to train.
        params (namespace): Training parameters.
        last_save_file (str, optional): Path of a checkpoint to load
            before training begins.
    """
    if last_save_file:
        model.load(last_save_file)

    # Get the training batches.
    log = Logger(os.path.join(params.logdir, params.logfile), "w")
    num_train_original = atis_data.num_utterances(data.train_data)
    log.put("Original number of training utterances:\t"
            + str(num_train_original))

    eval_fn = evaluate_utterance_sample
    trainbatch_fn = data.get_utterance_batches
    trainsample_fn = data.get_random_utterances
    validsample_fn = data.get_all_utterances
    batch_size = params.batch_size
    if params.interaction_level:
        batch_size = 1
        eval_fn = evaluate_interaction_sample
        trainbatch_fn = data.get_interaction_batches
        trainsample_fn = data.get_random_interactions
        validsample_fn = data.get_all_interactions

    maximum_output_length = params.train_maximum_sql_length
    train_batches = trainbatch_fn(batch_size,
                                  max_output_length=maximum_output_length,
                                  randomize=not params.deterministic)

    if params.num_train >= 0:
        train_batches = train_batches[:params.num_train]

    training_sample = trainsample_fn(params.train_evaluation_size,
                                     max_output_length=maximum_output_length)
    valid_examples = validsample_fn(data.valid_data,
                                    max_output_length=maximum_output_length)

    num_train_examples = sum([len(batch) for batch in train_batches])
    num_steps_per_epoch = len(train_batches)

    log.put("Actual number of used training examples:\t"
            + str(num_train_examples))
    log.put("(Shortened by output limit of "
            + str(maximum_output_length) + ")")
    log.put("Number of steps per epoch:\t" + str(num_steps_per_epoch))
    log.put("Batch size:\t" + str(batch_size))

    print("Kept " + str(num_train_examples) + "/"
          + str(num_train_original) + " examples")
    print("Batch size of " + str(batch_size) + " gives "
          + str(num_steps_per_epoch) + " steps per epoch")

    # Keeping track of things during training.
    epochs = 0
    patience = params.initial_patience
    learning_rate_coefficient = 1.
    previous_epoch_loss = float('inf')
    maximum_validation_accuracy = 0.
    maximum_string_accuracy = 0.
    countdown = int(patience)

    keep_training = True
    while keep_training:
        log.put("Epoch:\t" + str(epochs))

        model.set_dropout(params.dropout_amount)
        model.set_learning_rate(
            learning_rate_coefficient * params.initial_learning_rate)

        # Run a training step.
        if params.interaction_level:
            epoch_loss = train_epoch_with_interactions(
                train_batches,
                params,
                model,
                randomize=not params.deterministic)
        else:
            epoch_loss = train_epoch_with_utterances(
                train_batches,
                model,
                randomize=not params.deterministic)

        log.put("train epoch loss:\t" + str(epoch_loss))

        model.set_dropout(0.)

        # Run an evaluation step on a sample of the training data.
        train_eval_results = eval_fn(training_sample,
                                     model,
                                     params.train_maximum_sql_length,
                                     "evals/train-eval",
                                     gold_forcing=True,
                                     metrics=TRAIN_EVAL_METRICS,
                                     write_results=True)[0]

        for name, value in train_eval_results.items():
            log.put("train final gold-passing " + name.name + ":\t"
                    + "%.2f" % value)

        # Run an evaluation step on the validation set.
        suffix = "-new" if params.new_version else "-old"
        valid_eval_results = eval_fn(
            valid_examples,
            model,
            params.train_maximum_sql_length,
            "evals/valid-eval" + suffix + str(epochs),
            gold_forcing=True,
            database_username=params.database_username,
            database_password=params.database_password,
            database_timeout=params.database_timeout,
            metrics=VALID_EVAL_METRICS_WITHOUT_MYSQL,
            write_results=True)[0]

        for name, value in valid_eval_results.items():
            log.put("valid gold-passing " + name.name + ":\t"
                    + "%.2f" % value)

        valid_loss = valid_eval_results[Metrics.LOSS]
        valid_token_accuracy = valid_eval_results[Metrics.TOKEN_ACCURACY]
        string_accuracy = valid_eval_results[Metrics.STRING_ACCURACY]

        if valid_loss > previous_epoch_loss:
            learning_rate_coefficient *= params.learning_rate_ratio
            log.put("learning rate coefficient:\t"
                    + str(learning_rate_coefficient))
        previous_epoch_loss = valid_loss

        saved = False
        if valid_token_accuracy > maximum_validation_accuracy:
            saved = True
            maximum_validation_accuracy = valid_token_accuracy
            patience = patience * params.patience_ratio
            countdown = int(patience)
            last_save_file = os.path.join(params.logdir,
                                          "save_" + str(epochs))
            model.save(last_save_file)

            log.put("maximum accuracy:\t" + str(maximum_validation_accuracy))
            log.put("patience:\t" + str(patience))
            log.put("save file:\t" + str(last_save_file))

        if not saved and string_accuracy > maximum_string_accuracy:
            maximum_string_accuracy = string_accuracy
            log.put("maximum string accuracy:\t"
                    + str(maximum_string_accuracy))

            save_path = "save_" + str(epochs)
            if params.interaction_level:
                save_path += "-interaction"
            else:
                save_path += "-utterance"
            # Use the suffixed path so these checkpoints do not overwrite
            # the token-accuracy checkpoints saved above.
            last_save_file = os.path.join(params.logdir, save_path)
            model.save(last_save_file)

        send_slack_message(
            username=params.logdir,
            message="Epoch " + str(epochs) + ": " + str(string_accuracy)
            + " validation accuracy; countdown is " + str(countdown),
            channel="models")

        if countdown <= 0:
            keep_training = False

        countdown -= 1
        log.put("countdown:\t" + str(countdown))
        log.put("")

        epochs += 1

    log.put("Finished training!")
    send_slack_message(username=params.logdir,
                       message="Done training!!",
                       channel="@alsuhr")
    log.close()

    return last_save_file
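

# The "-utterance"/"-interaction" checkpoint suffixes above suggest a
# two-stage regimen: pretrain at the utterance level, then fine-tune at the
# interaction level from the saved checkpoint. A sketch of that pattern,
# assuming `params` carries the fields listed in _example_params() and can
# be copied; the helper name is illustrative, not part of the project.
def train_two_stage(model, data, params):
    """Utterance-level pass followed by interaction-level fine-tuning."""
    import copy

    # Stage 1: utterance-level training from scratch.
    utterance_params = copy.deepcopy(params)
    utterance_params.interaction_level = False
    checkpoint = train(model, data, utterance_params)

    # Stage 2: interaction-level fine-tuning, warm-started from the
    # checkpoint returned by stage 1 via the last_save_file argument.
    interaction_params = copy.deepcopy(params)
    interaction_params.interaction_level = True
    return train(model, data, interaction_params, last_save_file=checkpoint)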