Beispiel #1
0
def train(network, x, y, epochs=3):
    """Train `network` on the paired samples (x, y) for `epochs` epochs.

    Each epoch streams the zipped (x, y) pairs through progress reporting,
    `vec2img` and `augment`, shuffles them in a buffer of 1000, batches them
    with `build_batch` and feeds the batches to `network.train()`. The mean
    of the collected per-batch training outputs is printed after every epoch.
    """
    for ep in range(1, epochs + 1):
        print('epoch', ep)
        # Stage 1: per-sample preprocessing (progress bar, conversion, augmentation).
        samples = zip(x, y) >> nf.PrintProgress(x) >> vec2img >> augment
        # Stage 2: shuffle within a 1000-element buffer, then group into batches.
        batches = samples >> nf.Shuffle(1000) >> build_batch
        # Stage 3: train on every batch and gather what network.train() emits.
        losses = batches >> network.train() >> nf.Collect()
        print('train loss:', np.mean(losses))
Beispiel #2
0
def train(network, epochs=3):
    """Load MNIST and train `network` on it for `epochs` epochs.

    After each epoch the mean training loss is printed. Accuracy on the test
    split and on the training split (each via `evaluate`) is printed as well.
    Training outputs are also routed through a live `nm.PlotLines` plot.
    """
    print('loading data...')
    x_train, y_train, x_test, y_test = load_mnist(download_mnist())

    plot = nm.PlotLines(None, every_sec=0.2)
    # Column 0 of each sample is the input vector, column 1 the integer label.
    build_batch = (nm.BuildBatch(128, verbose=False)
                   .input(0, 'vector', 'float32')
                   .output(1, 'number', 'int64'))

    for ep in range(1, epochs + 1):
        print('epoch', ep)
        samples = zip(x_train, y_train) >> nf.PrintProgress(x_train)
        batches = samples >> nf.Shuffle(1000) >> build_batch
        losses = batches >> network.train() >> plot >> nf.Collect()

        # Test accuracy is computed first, then training accuracy.
        acc_test = evaluate(network, x_test, y_test)
        acc_train = evaluate(network, x_train, y_train)

        print('train loss : {:.6f}'.format(np.mean(losses)))
        print('train acc  : {:.1f}'.format(acc_train))
        print('test acc   : {:.1f}'.format(acc_test))
Beispiel #3
0
def predict(network, x, y):
    """Run `network.predict()` over (x, y) and print the resulting accuracy.

    The samples go through progress reporting, `vec2img` and
    `build_pred_batch`. The collected predictions are then scored against `y`
    with `accuracy`.
    """
    batches = zip(x, y) >> nf.PrintProgress(x) >> vec2img >> build_pred_batch
    preds = batches >> network.predict() >> nf.Collect()
    print('test acc', accuracy(y, preds))
Beispiel #4
0
def validate(network, x, y):
    """Print the validation/test loss: the mean of the per-batch losses.

    Unlike training, no augmentation or shuffling is applied. The samples are
    only converted with `vec2img` and batched with `build_batch`.
    """
    batches = zip(x, y) >> nf.PrintProgress(x) >> vec2img >> build_batch
    losses = batches >> network.validate() >> nf.Collect()
    print('val loss:', np.mean(losses))
Beispiel #5
0
def evaluate(network, x, y):
    """Evaluate `network` on (x, y) using `accuracy` as the only metric.

    Returns whatever `network.evaluate([accuracy])` produces for the batched
    stream. Callers in this file format it as a single number with
    '{:.1f}', so presumably it is a scalar — confirm against the network
    implementation.
    """
    batches = zip(x, y) >> nf.PrintProgress(x) >> vec2img >> build_batch
    return batches >> network.evaluate([accuracy])