コード例 #1
0
def test_lenet(devs, kv_type):
    """Fit the ``common.lenet()`` symbol and report its accuracy.

    Args:
        devs: Forwarded as ``ctx`` to ``mx.model.FeedForward.create``.
        kv_type: Forwarded as ``kvstore`` to ``mx.model.FeedForward.create``.

    Returns:
        The result of ``common.accuracy`` for the fitted model on the
        validation data returned by ``common.mnist``.
    """
    # Seed first so that each run starts from identical initial weights.
    mx.random.seed(0)
    logging.basicConfig(level=logging.DEBUG)

    # Alternative data source kept for reference:
    #   common.cifar10(batch_size=128, input_shape=(3, 28, 28))
    train, val = common.mnist(batch_size=100, input_shape=(1, 28, 28))

    # Keyword order mirrors the order in which the arguments are evaluated.
    create_kwargs = {
        "ctx": devs,
        "kvstore": kv_type,
        "symbol": common.lenet(),
        "X": train,
        "num_epoch": 3,
        "learning_rate": 0.1,
        "momentum": 0.9,
        "wd": 0.00001,
    }
    model = mx.model.FeedForward.create(**create_kwargs)

    return common.accuracy(model, val)
コード例 #2
0
ファイル: local_lenet.py プロジェクト: reking/mxnet
def test_lenet(devs, kv_type):
    """Fit the ``common.lenet()`` symbol and return ``common.accuracy``'s result.

    Args:
        devs: Passed through as the ``ctx`` argument of
            ``mx.model.FeedForward.create``.
        kv_type: Passed through as the ``kvstore`` argument of
            ``mx.model.FeedForward.create``.

    Returns:
        Whatever ``common.accuracy`` yields for the fitted model and the
        validation data from ``common.mnist``.
    """
    # A fixed seed keeps the weight initialisation identical between runs.
    mx.random.seed(0)
    logging.basicConfig(level=logging.DEBUG)

    # Alternative data source kept for reference:
    #   common.cifar10(batch_size=128, input_shape=(3, 28, 28))
    train_iter, val_iter = common.mnist(batch_size=100, input_shape=(1, 28, 28))

    # NOTE(review): ``num_round`` looks like the older spelling of what later
    # MXNet releases call ``num_epoch`` -- confirm against the MXNet version
    # this project pins before renaming.
    fitted = mx.model.FeedForward.create(
        ctx=devs,
        kvstore=kv_type,
        symbol=common.lenet(),
        X=train_iter,
        num_round=3,
        learning_rate=0.1,
        momentum=0.9,
        wd=0.00001,
    )

    return common.accuracy(fitted, val_iter)
コード例 #3
0
#!/usr/bin/env python
import common
import mxnet as mx
import logging

# Seed the MXNet RNG and turn on debug logging before anything else runs.
mx.random.seed(0)
logging.basicConfig(level=logging.DEBUG)

# Key-value store created with the 'dist_async' type string.
kv = mx.kvstore.create('dist_async')

# Partition the data by worker: one part per worker, indexed by this rank.
train, val = common.mnist(
    num_parts=kv.num_workers,
    part_index=kv.rank,
    batch_size=100,
    input_shape=(1, 28, 28),
)

# The GPU index is taken from this worker's rank.
model = mx.model.FeedForward.create(
    ctx=mx.gpu(kv.rank),
    kvstore=kv,
    symbol=common.lenet(),
    X=train,
    num_epoch=10,
    learning_rate=0.05,
    momentum=0.9,
    wd=0.00001,
)

# The return value is intentionally not used here.
common.accuracy(model, val)
コード例 #4
0
ファイル: dist_async_lenet.py プロジェクト: Aspart/mxnet
#!/usr/bin/env python
import common
import mxnet as mx
import logging

# Fixed seed plus debug-level logging, configured up front.
mx.random.seed(0)
logging.basicConfig(level=logging.DEBUG)

# Key-value store created from the 'dist_async' type string.
kv = mx.kvstore.create('dist_async')

# Each worker asks for its own part of the data, selected by its rank.
train, val = common.mnist(
    num_parts=kv.num_workers,
    part_index=kv.rank,
    batch_size=100,
    input_shape=(1, 28, 28),
)

# This worker's rank doubles as the GPU index passed to mx.gpu.
model = mx.model.FeedForward.create(
    ctx=mx.gpu(kv.rank),
    kvstore=kv,
    symbol=common.lenet(),
    X=train,
    num_epoch=10,
    learning_rate=0.05,
    momentum=0.9,
    wd=0.00001,
)

# Evaluate on the validation data; the result is discarded by this script.
common.accuracy(model, val)