Example 1
def _validate_schema(cfg: dict):
    """Validate config params against schema.

    Args:
        cfg: dict, config.

    Raises:
        jsonschema.ValidationError: invalid config params.
    """
    # The schema ships alongside this module.
    schema = io_utils.load_json(
        "schema.json", os.path.abspath(os.path.dirname(__file__))
    )
    try:
        jsonschema.validate(cfg, schema)
    except jsonschema.ValidationError as err:
        # Chain the original error so the full validation context is kept.
        raise jsonschema.ValidationError(f"invalid config: {err}") from err

    # Check model outputs have unique names
    outputs = cfg["model"]["outputs"]
    num_outputs = len(outputs)
    num_unique_names = len({o["name"] for o in outputs})
    if num_outputs != num_unique_names:
        raise jsonschema.ValidationError("'outputs' names are not unique")

    # Check that multi-output networks have loss weights for each output
    if num_outputs > 1:
        for output in outputs:
            if "loss_weight" not in output:
                raise jsonschema.ValidationError(
                    "'outputs' requires 'loss_weights' for multiple outputs"
                )
Example 2
def test_multi_output(artifact_dir):
    """Train a multi-output network and sanity-check every output head."""
    records_train = gen_records(NUM_SAMPLES_TRAIN)
    records_validation = gen_records(NUM_SAMPLES_VALIDATION)
    records_score = gen_records(NUM_SAMPLES_SCORE)

    loc = os.path.abspath(os.path.dirname(__file__))
    cfg = io_utils.load_json("config_multi_output.json", loc)

    bm = BarrageModel(artifact_dir)
    bm.train(cfg, records_train, records_validation)
    scores = bm.predict(records_score)

    classification = [np.argmax(score["classification"]) for score in scores]
    regression_1 = [score["regression"][0] for score in scores]
    regression_2 = [score["regression"][1] for score in scores]

    df_scores = pd.DataFrame({
        "classification": classification,
        "regression_1": regression_1,
        "regression_2": regression_2,
    })
    # BUGFIX: wrap records_score in a DataFrame before the column lookups
    # below, matching test_simple_output (presumably gen_records returns a
    # list of dicts — confirm against its definition).
    records_score = pd.DataFrame(records_score)

    assert (df_scores["classification"] == records_score["y_cls"]).mean() > 0.5
    assert abs(
        (df_scores["regression_1"] - records_score["y_reg_1"]).mean()) < 0.5
    assert abs(
        (df_scores["regression_2"] - records_score["y_reg_2"]).mean()) < 0.5
Example 3
def test_load_json(artifact_path, sample_dict):
    """Round-trip: write a dict as JSON, reload via io_utils, compare."""
    filename = "unit_test.json"
    target = os.path.join(artifact_path, filename)
    with open(target, "w") as handle:
        json.dump(sample_dict, handle)
    assert os.path.isfile(target)

    loaded = io_utils.load_json(filename, artifact_path)
    assert loaded == sample_dict
Example 4
def test_simple_output(artifact_dir, records_train, records_validation,
                       records_score):
    """End-to-end single-output run: train, predict, check accuracy >= 0.90."""
    here = os.path.abspath(os.path.dirname(__file__))
    config = io_utils.load_json("config_single_output.json", here)

    model = BarrageModel(artifact_dir)
    model.train(config, records_train, records_validation)
    predictions = model.predict(records_score)

    df_preds = pd.DataFrame(predictions)
    assert (df_preds["softmax"] == records_score["label"]).mean() >= 0.90
Example 5
def test_simple_output(artifact_dir):
    """Generate synthetic records, train a single-output model, check accuracy."""
    records_train = gen_records(NUM_SAMPLES_TRAIN)
    records_validation = gen_records(NUM_SAMPLES_VALIDATION)
    records_score = gen_records(NUM_SAMPLES_SCORE)

    here = os.path.abspath(os.path.dirname(__file__))
    config = io_utils.load_json("config_single_output.json", here)

    model = BarrageModel(artifact_dir)
    model.train(config, records_train, records_validation)
    predictions = model.predict(records_score)

    df_preds = pd.DataFrame(predictions)
    df_truth = pd.DataFrame(records_score)
    assert (df_preds["softmax"] == df_truth["label"]).mean() >= 0.90
Example 6
def train(config, train_data, validation_data, artifact_dir):
    """Barrage deep learning train.

    Supported filetypes:

        1. .csv

        2. .json

    Args:

        config: filepath to barrage config [REQUIRED].

        train-data: filepath to train data [REQUIRED].

        validation-data: filepath to validation data [REQUIRED].

    Note: artifact-dir cannot already exist.
    """
    cfg = io_utils.load_json(config)
    records_train = io_utils.load_data(train_data)
    # BUGFIX: previously loaded train_data again here, so the model silently
    # validated on its own training set.
    records_validation = io_utils.load_data(validation_data)
    BarrageModel(artifact_dir).train(cfg, records_train, records_validation)
Example 7
    def postprocess(self, score):
        """Binarize a single output at 0.5, or argmax a multi-class score.

        Mutates and returns *score*, replacing ``score[self.out_key]``.
        """
        key = self.out_key
        if len(score) == 1:
            # Single output: threshold the scalar at 0.5 -> 0.0 or 1.0.
            score[key] = float(score[key] > 0.5)
        else:
            # Multiple entries: take the index of the highest value.
            score[key] = np.argmax(score[key])
        return score

    def load(self, path):
        """Restore the fitted tokenizer from ``tokenizer.pkl`` under ``path``."""
        self.tokenizer = io_utils.load_pickle("tokenizer.pkl", path)

    def save(self, path):
        """Persist the fitted tokenizer as ``tokenizer.pkl`` under ``path``."""
        io_utils.save_pickle(self.tokenizer, "tokenizer.pkl", path)


if __name__ == "__main__":
    records_train, records_val, records_test = get_data()

    # Train a model from the sentiment JSON config.
    config = io_utils.load_json("config_sentiment.json")
    BarrageModel("artifacts").train(config, records_train, records_val)

    # Score the held-out test records and measure accuracy vs. labels.
    predictions = BarrageModel("artifacts").predict(records_test)
    preds_frame = pd.DataFrame(predictions)

    acc = (preds_frame["target"] == records_test["label"]).mean()
    print(f"Test set accuracy: {acc}")
Example 8
MNIST dataset example
"""
import numpy as np
from tensorflow.keras import datasets

from barrage import BarrageModel
from barrage.utils import io_utils


def get_data():
    """Load MNIST dataset."""
    (X_train, y_train), (X_val, y_val) = datasets.mnist.load_data()
    # The model expects image shape (28, 28, 1), not (28, 28): add a channel axis.
    X_train = np.expand_dims(X_train, axis=-1)
    X_val = np.expand_dims(X_val, axis=-1)

    # Convert each split into a list of {"x": image, "y": label} records.
    records_train = [{"x": x, "y": y} for x, y in zip(X_train, y_train)]
    records_val = [{"x": x, "y": y} for x, y in zip(X_val, y_val)]
    return records_train, records_val


if __name__ == "__main__":
    records_train, records_val = get_data()
    # Train an MNIST model from the JSON config.
    config = io_utils.load_json("config_mnist.json")
    BarrageModel("artifacts").train(config, records_train, records_val)