예제 #1
0
def find_best_checkpoint(log_dir):
    """Locate the checkpoint belonging to the best-AUC epoch of a run.

    Every file name (extension stripped) found in ``<log_dir>/roc_curve`` is
    parsed with ``parse_str``; the entry with the largest ``"auc"`` value is
    selected. The checkpoint whose parsed epoch equals that entry's
    ``"epoch"`` is then searched for in ``<log_dir>/checkpoint``, falling
    back to ``<log_dir>/model_checkpoint`` when the former does not exist.

    Args:
        log_dir: Path of the run's log directory. Trailing ``"/"``
            characters are ignored when deriving the run name.

    Returns:
        The parsed entry of the best ROC curve, extended with the keys
        ``"name"`` (base name of ``log_dir``) and ``"path"`` (absolute path
        of the matching checkpoint file), or ``None`` when the ``roc_curve``
        directory is empty.

    Raises:
        FileNotFoundError: If no file in the checkpoint directory matches
            the best epoch. (The same exception type propagates from
            ``os.listdir`` when a directory is missing.)
    """
    # rstrip is a no-op when there is no trailing slash, so one call
    # covers both spellings of the path.
    name = os.path.basename(log_dir.rstrip("/"))

    roc_curve_dir = os.path.join(log_dir, "roc_curve")
    roc_curve_files = os.listdir(roc_curve_dir)
    if not roc_curve_files:
        return None

    parsed_roc_curve = [
        parse_str(os.path.splitext(each)[0]) for each in roc_curve_files
    ]
    best = max(parsed_roc_curve, key=lambda each: each["auc"])
    best["name"] = name

    ckpt_dir = os.path.join(log_dir, "checkpoint")
    if not os.path.exists(ckpt_dir):
        ckpt_dir = os.path.join(log_dir, "model_checkpoint")

    for each in os.listdir(ckpt_dir):
        ckpt_name = os.path.splitext(each)[0]

        # The first checkpoint whose epoch matches wins.
        if parse_str(ckpt_name, "epoch") == best["epoch"]:
            best["path"] = os.path.abspath(os.path.join(ckpt_dir, each))
            return best

    # Without this, a missing checkpoint surfaced as an UnboundLocalError
    # on the local path variable, which hid the real cause.
    raise FileNotFoundError(
        "no checkpoint matching epoch {} found in {}".format(
            best["epoch"], ckpt_dir))
예제 #2
0
def evaluate(checkpoint_path,
             train_iter,
             test_iter,
             log_dir):
    """Run a saved model over both datasets and write its evaluation plots.

    The model is loaded from ``checkpoint_path`` and the epoch number is
    parsed out of the checkpoint's file name. Predictions on every training
    and test batch feed a ``BinaryClassifierResponse``; test batches also
    feed a ``ROCCurve``. Both are finalised via ``finish()`` at the end.

    Args:
        checkpoint_path: Path of the checkpoint file; ``".hdf5"`` is removed
            from its base name before the epoch is parsed.
        train_iter: Iterable of training batches exposing ``x`` and ``y``.
        test_iter: Iterable of test batches exposing ``x`` and ``y``.
        log_dir: Object exposing ``roc_curve.path`` and
            ``model_response.path``, used as output directories.

    Returns:
        None.
    """
    model = load_model(checkpoint_path)

    # Recover the epoch number from the checkpoint's file name.
    checkpoint_stem = os.path.basename(checkpoint_path).replace(".hdf5", "")
    epoch = parse_str(checkpoint_stem, target="epoch")

    # Both plots share one name and one title.
    labels = {
        "name": "model_epoch-{:02d}".format(epoch),
        "title": "Quark/Gluon Jet Discrimination (Epoch {})".format(epoch),
    }

    roc_curve = ROCCurve(directory=log_dir.roc_curve.path, **labels)
    model_response = BinaryClassifierResponse(
        directory=log_dir.model_response.path, **labels)

    # --- training data: classifier response only ---
    print("TRAINING SET")
    for train_batch in train_iter:
        train_scores = model.predict_on_batch([train_batch.x])
        model_response.append(is_train=True,
                              y_true=train_batch.y,
                              y_score=train_scores)

    # --- test data (the dijet dataset, per the original comment):
    # classifier response and ROC curve ---
    print("TEST SET")
    for test_batch in test_iter:
        test_scores = model.predict_on_batch([test_batch.x])
        model_response.append(is_train=False,
                              y_true=test_batch.y,
                              y_score=test_scores)
        roc_curve.append(y_true=test_batch.y, y_score=test_scores)

    roc_curve.finish()
    model_response.finish()