Esempio n. 1
0
def train(
    sentences_train,
    labels_train,
    sentences_valid,
    labels_valid,
    batch_size=128,
    n_epochs=10,
):
    """Fit a fresh ``Network`` on the given tensors and return the model.

    The training and validation pairs are each wrapped in a
    ``TensorDataset`` and served through a ``DataLoader`` (only the
    training loader shuffles). The loaders are bundled into a
    ``DataBunch`` and handed to a ``Learner`` together with the
    module-level ``loss`` function, which is then trained with
    ``fit_one_cycle`` for ``n_epochs`` epochs.

    Args:
        sentences_train: Training inputs, passed to ``TensorDataset``.
        labels_train: Training targets, passed to ``TensorDataset``.
        sentences_valid: Validation inputs, passed to ``TensorDataset``.
        labels_valid: Validation targets, passed to ``TensorDataset``.
        batch_size: Batch size used by both loaders.
        n_epochs: Number of epochs given to ``fit_one_cycle``.

    Returns:
        The ``model`` attribute of the learner after fitting.
    """
    training_set = data.TensorDataset(sentences_train, labels_train)
    validation_set = data.TensorDataset(sentences_valid, labels_valid)

    network = Network()

    # Both loaders share everything except the shuffle flag.
    shared_loader_args = dict(batch_size=batch_size, pin_memory=False)
    training_batches = data.DataLoader(
        training_set, shuffle=True, **shared_loader_args
    )
    validation_batches = data.DataLoader(
        validation_set, shuffle=False, **shared_loader_args
    )

    learner = Learner(
        DataBunch(train_dl=training_batches, valid_dl=validation_batches),
        network,
        loss_func=loss,
    )

    # Only call to_fp16() when a CUDA device is present.
    if torch.cuda.is_available():
        learner = learner.to_fp16()

    learner.fit_one_cycle(n_epochs)
    return learner.model
Esempio n. 2
0
def train(train_dataset: torch.utils.data.Dataset,
          test_dataset: torch.utils.data.Dataset,
          training_config: dict = train_config,
          global_config: dict = global_config):
    """
    Template training routine.

    Takes a training and a test dataset wrapped as
    torch.utils.data.Dataset and two generic configs, one for the global
    path settings and one for the training settings. All training
    settings are read from ``training_config`` (which defaults to the
    module-level ``train_config``), so a caller-supplied config is
    honoured for every step, not only for model construction.

    Keys read from ``training_config``: DATA_LOADER_CONFIG, WEIGHTS,
    MODEL, MODEL_CONFIG, DEVICE (only when WEIGHTS is not None), METRICS,
    LOSS, DATE, SESSION_NAME, MIXED_PRECISION, ONE_CYCLE, EPOCHS, LR,
    LOGFILE.

    Keys read from ``global_config``: ROOT_PATH, WEIGHT_DIR, LOG_DIR.
    Every value of ``global_config`` is passed to ``create_dirs``.

    Returns a tuple ``(learner, log_content, name)``:
      * learner     -- the fitted Learner object, which can be used to
                       assess the resulting metrics and error curves,
      * log_content -- a copy of ``training_config`` extended with
                       "VAL_LOSS" and "VAL_METRICS",
      * name        -- the "<DATE>_<SESSION_NAME>" run name.

    Raises KeyboardInterrupt (re-raised after saving the model) when
    training is interrupted by the user.
    """

    for path in global_config.values():
        create_dirs(path)

    # wrap datasets with Dataloader classes; both loaders share the same
    # loader configuration
    loader_config = training_config["DATA_LOADER_CONFIG"]
    train_loader = torch.utils.data.DataLoader(train_dataset, **loader_config)
    test_loader = torch.utils.data.DataLoader(test_dataset, **loader_config)
    databunch = DataBunch(train_loader, test_loader)

    # instantiate model and learner
    if training_config["WEIGHTS"] is None:
        model = training_config["MODEL"](**training_config["MODEL_CONFIG"])
    else:
        model = load_model(training_config["MODEL"],
                           training_config["MODEL_CONFIG"],
                           training_config["WEIGHTS"],
                           training_config["DEVICE"])

    learner = Learner(databunch,
                      model,
                      metrics=training_config["METRICS"],
                      path=global_config["ROOT_PATH"],
                      model_dir=global_config["WEIGHT_DIR"],
                      loss_func=training_config["LOSS"])

    # model name & paths
    name = "_".join([training_config["DATE"],
                     training_config["SESSION_NAME"]])
    modelpath = os.path.join(global_config["WEIGHT_DIR"], name)

    if training_config["MIXED_PRECISION"]:
        learner.to_fp16()

    # save once before training starts
    learner.save(modelpath)

    torch.backends.cudnn.benchmark = True

    cbs = [
        SaveModelCallback(learner),
        LearnerTensorboardWriter(
            learner,
            Path(global_config["LOG_DIR"], "tensorboardx"),
            name),
        TerminateOnNaNCallback()
    ]

    # perform training iteration
    try:
        if training_config["ONE_CYCLE"]:
            learner.fit_one_cycle(training_config["EPOCHS"],
                                  max_lr=training_config["LR"],
                                  callbacks=cbs)
        else:
            learner.fit(training_config["EPOCHS"],
                        lr=training_config["LR"],
                        callbacks=cbs)
    except KeyboardInterrupt:
        # save model files, then re-raise the original exception so its
        # traceback is preserved
        learner.save(modelpath)
        raise

    learner.save(modelpath)
    val_loss = min(learner.recorder.val_losses)
    val_metrics = learner.recorder.metrics

    # log using the logging tool
    logger = log.Log(training_config,
                     run_name=training_config['SESSION_NAME'])
    logger.log_metric('Validation Loss', val_loss)
    # NOTE(review): `logger.log.metrics` looks like it may be a typo for a
    # `log_metrics` method -- confirm against the `log.Log` class.
    logger.log.metrics(val_metrics)
    logger.end_run()

    # write csv log file
    log_content = training_config.copy()
    log_content["VAL_LOSS"] = val_loss
    log_content["VAL_METRICS"] = val_metrics
    log_path = os.path.join(global_config["LOG_DIR"],
                            training_config["LOGFILE"])
    write_log(log_path, log_content)

    return learner, log_content, name