Ejemplo n.º 1
0
def train(model, data, params):
    """ Trains a model.

    Each epoch runs one pass over the training batches, then evaluates (with
    gold_forcing=True) on a sample of the training data and on the validation
    examples. The model is saved whenever the validation string accuracy
    exceeds the best value seen so far; training stops once the patience
    countdown reaches zero.

    Inputs:
        model (ATISModel): The model to train.
        data (ATISData): The data that is used to train.
        params (namespace): Training parameters.

    Returns:
        The path of the most recent save file, or None if the validation
        string accuracy never improved and therefore nothing was saved.
    """
    log = Logger(os.path.join(params.logdir, params.logfile), "w")

    # Nothing has been saved yet; stays None unless string accuracy improves.
    last_save_file = None

    # The finally clause guarantees the log is closed even if training or
    # evaluation raises part-way through.
    try:
        num_train_original = atis_data.num_utterances(data.train_data)
        log.put("Original number of training utterances:\t"
                + str(num_train_original))

        # Pick utterance-level or interaction-level helpers.
        if params.interaction_level:
            # Interaction-level training always uses one interaction per batch.
            batch_size = 1
            eval_fn = evaluate_interaction_sample
            trainbatch_fn = data.get_interaction_batches
            trainsample_fn = data.get_random_interactions
            validsample_fn = data.get_all_interactions
        else:
            batch_size = params.batch_size
            eval_fn = evaluate_utterance_sample
            trainbatch_fn = data.get_utterance_batches
            trainsample_fn = data.get_random_utterances
            validsample_fn = data.get_all_utterances

        # Get the training batches.
        maximum_output_length = params.train_maximum_sql_length
        train_batches = trainbatch_fn(batch_size,
                                      max_output_length=maximum_output_length,
                                      randomize=not params.deterministic)

        # A non-negative num_train caps the number of batches that are used.
        if params.num_train >= 0:
            train_batches = train_batches[:params.num_train]

        training_sample = trainsample_fn(
            params.train_evaluation_size,
            max_output_length=maximum_output_length)
        valid_examples = validsample_fn(
            data.valid_data,
            max_output_length=maximum_output_length)

        num_train_examples = sum(len(batch) for batch in train_batches)
        num_steps_per_epoch = len(train_batches)

        log.put(
            "Actual number of used training examples:\t" +
            str(num_train_examples))
        log.put("(Shortened by output limit of " +
                str(maximum_output_length) +
                ")")
        log.put("Number of steps per epoch:\t" + str(num_steps_per_epoch))
        log.put("Batch size:\t" + str(batch_size))

        print(
            "Kept " +
            str(num_train_examples) +
            "/" +
            str(num_train_original) +
            " examples")
        print(
            "Batch size of " +
            str(batch_size) +
            " gives " +
            str(num_steps_per_epoch) +
            " steps per epoch")

        # Keeping track of things during training.
        epochs = 0
        patience = params.initial_patience
        learning_rate_coefficient = 1.
        previous_epoch_loss = float('inf')
        maximum_string_accuracy = 0.

        countdown = int(patience)

        if params.scheduler:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                model.trainer, mode='min')

        keep_training = True
        while keep_training:
            log.put("Epoch:\t" + str(epochs))
            model.set_dropout(params.dropout_amount)

            # The manual coefficient is only applied when no scheduler is used.
            if not params.scheduler:
                model.set_learning_rate(
                    learning_rate_coefficient * params.initial_learning_rate)

            # Run a training step.
            if params.interaction_level:
                epoch_loss = train_epoch_with_interactions(
                    train_batches,
                    params,
                    model,
                    randomize=not params.deterministic)
            else:
                epoch_loss = train_epoch_with_utterances(
                    train_batches,
                    model,
                    randomize=not params.deterministic)

            log.put("train epoch loss:\t" + str(epoch_loss))

            # Dropout is switched off for both evaluation passes.
            model.set_dropout(0.)

            # Run an evaluation step on a sample of the training data.
            train_eval_results = eval_fn(
                training_sample,
                model,
                params.train_maximum_sql_length,
                name=os.path.join(params.logdir, "train-eval"),
                write_results=True,
                gold_forcing=True,
                metrics=TRAIN_EVAL_METRICS)[0]

            for name, value in train_eval_results.items():
                log.put(
                    "train final gold-passing " +
                    name.name +
                    ":\t" +
                    "%.2f" %
                    value)

            # Run an evaluation step on the validation set.
            valid_eval_results = eval_fn(
                valid_examples,
                model,
                params.eval_maximum_sql_length,
                name=os.path.join(params.logdir, "valid-eval"),
                write_results=True,
                gold_forcing=True,
                metrics=VALID_EVAL_METRICS)[0]
            for name, value in valid_eval_results.items():
                log.put(
                    "valid gold-passing " + name.name + ":\t" + "%.2f" % value)

            valid_loss = valid_eval_results[Metrics.LOSS]
            string_accuracy = valid_eval_results[Metrics.STRING_ACCURACY]

            if params.scheduler:
                scheduler.step(valid_loss)

            # Shrink the coefficient whenever validation loss got worse.
            if valid_loss > previous_epoch_loss:
                learning_rate_coefficient *= params.learning_rate_ratio
                log.put(
                    "learning rate coefficient:\t" +
                    str(learning_rate_coefficient))

            previous_epoch_loss = valid_loss

            # A new best string accuracy scales the patience, resets the
            # countdown and saves the model.
            if string_accuracy > maximum_string_accuracy:
                maximum_string_accuracy = string_accuracy
                patience = patience * params.patience_ratio
                countdown = int(patience)
                last_save_file = os.path.join(
                    params.logdir, "save_" + str(epochs))
                model.save(last_save_file)

                log.put(
                    "maximum string accuracy:\t" +
                    str(maximum_string_accuracy))
                log.put("patience:\t" + str(patience))
                log.put("save file:\t" + str(last_save_file))

            # The check happens before the decrement, so the current epoch is
            # the last one once the countdown has reached zero.
            if countdown <= 0:
                keep_training = False

            countdown -= 1
            log.put("countdown:\t" + str(countdown))
            log.put("")

            epochs += 1

        log.put("Finished training!")
    finally:
        log.close()

    return last_save_file
Ejemplo n.º 2
0
def train(model, data, params, last_save_file=None):
    """ Trains a model.

    Each epoch runs one pass over the training batches, then evaluates (with
    gold_forcing=True) on a sample of the training data and on the validation
    examples. The model is saved when the validation token accuracy reaches a
    new maximum (which also scales the patience and resets the countdown), or
    otherwise when the validation string accuracy reaches a new maximum.
    A Slack message is sent after every epoch and once more when training
    finishes.

    Inputs:
        model (ATISModel): The model to train.
        data (ATISData): The data that is used to train.
        params (namespace): Training parameters.
        last_save_file (str, optional): If given, the model is loaded from
            this file before training starts.

    Returns:
        The path of the most recent save file. If nothing was saved during
        this run, the last_save_file argument is returned unchanged.
    """
    if last_save_file:
        model.load(last_save_file)

    log = Logger(os.path.join(params.logdir, params.logfile), "w")

    # The finally clause guarantees the log is closed even if training or
    # evaluation raises part-way through.
    try:
        num_train_original = atis_data.num_utterances(data.train_data)
        log.put("Original number of training utterances:\t"
                + str(num_train_original))

        # Pick utterance-level or interaction-level helpers.
        if params.interaction_level:
            # Interaction-level training always uses one interaction per batch.
            batch_size = 1
            eval_fn = evaluate_interaction_sample
            trainbatch_fn = data.get_interaction_batches
            trainsample_fn = data.get_random_interactions
            validsample_fn = data.get_all_interactions
        else:
            batch_size = params.batch_size
            eval_fn = evaluate_utterance_sample
            trainbatch_fn = data.get_utterance_batches
            trainsample_fn = data.get_random_utterances
            validsample_fn = data.get_all_utterances

        # Get the training batches.
        maximum_output_length = params.train_maximum_sql_length
        train_batches = trainbatch_fn(batch_size,
                                      max_output_length=maximum_output_length,
                                      randomize=not params.deterministic)

        # A non-negative num_train caps the number of batches that are used.
        if params.num_train >= 0:
            train_batches = train_batches[:params.num_train]

        training_sample = trainsample_fn(
            params.train_evaluation_size,
            max_output_length=maximum_output_length)
        valid_examples = validsample_fn(
            data.valid_data,
            max_output_length=maximum_output_length)

        num_train_examples = sum(len(batch) for batch in train_batches)
        num_steps_per_epoch = len(train_batches)

        log.put(
            "Actual number of used training examples:\t" +
            str(num_train_examples))
        log.put("(Shortened by output limit of " +
                str(maximum_output_length) +
                ")")
        log.put("Number of steps per epoch:\t" + str(num_steps_per_epoch))
        log.put("Batch size:\t" + str(batch_size))

        print(
            "Kept " +
            str(num_train_examples) +
            "/" +
            str(num_train_original) +
            " examples")
        print(
            "Batch size of " +
            str(batch_size) +
            " gives " +
            str(num_steps_per_epoch) +
            " steps per epoch")

        # Keeping track of things during training.
        epochs = 0
        patience = params.initial_patience
        learning_rate_coefficient = 1.
        previous_epoch_loss = float('inf')
        maximum_validation_accuracy = 0.
        maximum_string_accuracy = 0.

        countdown = int(patience)

        # The validation output name is tagged with the model version.
        suffix = "-new" if params.new_version else "-old"

        keep_training = True
        while keep_training:
            log.put("Epoch:\t" + str(epochs))
            model.set_dropout(params.dropout_amount)
            model.set_learning_rate(
                learning_rate_coefficient *
                params.initial_learning_rate)

            # Run a training step.
            if params.interaction_level:
                epoch_loss = train_epoch_with_interactions(
                    train_batches,
                    params,
                    model,
                    randomize=not params.deterministic)
            else:
                epoch_loss = train_epoch_with_utterances(
                    train_batches,
                    model,
                    randomize=not params.deterministic)
            log.put("train epoch loss:\t" + str(epoch_loss))

            # Dropout is switched off for both evaluation passes.
            model.set_dropout(0.)

            # Run an evaluation step on a sample of the training data.
            train_eval_results = eval_fn(training_sample,
                                         model,
                                         params.train_maximum_sql_length,
                                         "evals/train-eval",
                                         gold_forcing=True,
                                         metrics=TRAIN_EVAL_METRICS,
                                         write_results=True)[0]

            for name, value in train_eval_results.items():
                log.put(
                    "train final gold-passing " +
                    name.name +
                    ":\t" +
                    "%.2f" %
                    value)

            # Run an evaluation step on the validation set.
            valid_eval_results = eval_fn(
                valid_examples,
                model,
                params.train_maximum_sql_length,
                "evals/valid-eval" + suffix + str(epochs),
                gold_forcing=True,
                database_username=params.database_username,
                database_password=params.database_password,
                database_timeout=params.database_timeout,
                metrics=VALID_EVAL_METRICS_WITHOUT_MYSQL,
                write_results=True)[0]
            for name, value in valid_eval_results.items():
                log.put(
                    "valid gold-passing " + name.name + ":\t" + "%.2f" % value)

            valid_loss = valid_eval_results[Metrics.LOSS]
            valid_token_accuracy = valid_eval_results[Metrics.TOKEN_ACCURACY]
            string_accuracy = valid_eval_results[Metrics.STRING_ACCURACY]

            # Shrink the learning rate whenever validation loss got worse.
            if valid_loss > previous_epoch_loss:
                learning_rate_coefficient *= params.learning_rate_ratio
                log.put(
                    "learning rate coefficient:\t" +
                    str(learning_rate_coefficient))
            previous_epoch_loss = valid_loss

            save_file = os.path.join(params.logdir, "save_" + str(epochs))
            if valid_token_accuracy > maximum_validation_accuracy:
                # A new best token accuracy scales the patience, resets the
                # countdown and saves the model.
                maximum_validation_accuracy = valid_token_accuracy
                patience = patience * params.patience_ratio
                countdown = int(patience)
                last_save_file = save_file
                model.save(last_save_file)

                log.put(
                    "maximum accuracy:\t" + str(maximum_validation_accuracy))
                log.put("patience:\t" + str(patience))
                log.put("save file:\t" + str(last_save_file))
            elif string_accuracy > maximum_string_accuracy:
                # Only reached when token accuracy did not already trigger a
                # save this epoch; patience and countdown are left untouched.
                maximum_string_accuracy = string_accuracy
                log.put(
                    "maximum string accuracy:\t" +
                    str(maximum_string_accuracy))
                # NOTE(review): an "-interaction"/"-utterance" suffix used to
                # be computed here but was never part of the file name; the
                # effective name is kept as-is. Confirm whether the suffix
                # was intended.
                last_save_file = save_file
                model.save(last_save_file)

            send_slack_message(
                username=params.logdir,
                message="Epoch " +
                str(epochs) +
                ": " +
                str(string_accuracy) +
                " validation accuracy; countdown is " +
                str(countdown),
                channel="models")

            # The check happens before the decrement, so the current epoch is
            # the last one once the countdown has reached zero.
            if countdown <= 0:
                keep_training = False

            countdown -= 1
            log.put("countdown:\t" + str(countdown))
            log.put("")

            epochs += 1

        log.put("Finished training!")
        send_slack_message(username=params.logdir,
                           message="Done training!!",
                           channel="@alsuhr")
    finally:
        log.close()

    return last_save_file