示例#1
0
文件: mnist.py 项目: nvagus/tforce
def main():
    """Train on the MNIST training split, then run validation on the validation split.

    Both splits are served by one ``t4.MultiNpzDataStream`` whose active split is
    chosen via ``stream.selected`` before each phase.

    Returns:
        0 (presumably a process exit status -- confirm with the caller).
    """
    model = Model()
    data_files = {
        'train': '/data/plan/mnist/mnist-train.npz',
        'valid': '/data/plan/mnist/mnist-valid.npz',
    }
    stream = t4.MultiNpzDataStream(data_files, 'image', 'label')
    model.setup(stream)

    with model.using_workers():
        # Training phase on the 'train' split.
        stream.selected = 'train'
        trainer = t4.trainer.Alice(model.train)
        # NOTE(review): the meaning of run()'s positional arguments (600, 1) is
        # defined by t4.trainer and is not visible here.
        trainer.run(600, 1)

        # Validation phase on the 'valid' split.
        stream.selected = 'valid'
        validator = t4.trainer.Bob(model.valid)
        validator.run()

    return 0
示例#2
0
文件: imagenet.py 项目: nvagus/tforce
def main():
    """Restore saved weights, train on the COCO training split, then validate.

    The checkpoint at ``/data/noise/models/core.npz`` is loaded inside the worker
    context (started with ``using_workers(50)``) before any training step runs.

    Returns:
        0 (presumably a process exit status -- confirm with the caller).
    """
    model = Model()
    splits = {
        'train': '/data/plan/coco/coco-train.npz',
        'valid': '/data/plan/coco/coco-valid.npz',
    }
    stream = t4.MultiNpzDataStream(splits, 'image', 'label')
    model.setup(stream)

    checkpoint_path = '/data/noise/models/core.npz'
    with model.using_workers(50):
        model.restore(checkpoint_path)

        stream.selected = 'train'
        t4.trainer.Alice(model.train).run(800, 1)

        stream.selected = 'valid'
        # Reuse Alice's default batch size for the validation trainer.
        valid_batch_size = t4.trainer.Alice.default.batch_size
        t4.trainer.Bob(model.valid, valid_batch_size).run(100)

    return 0
示例#3
0
def main():
    """Debug helper: print one batch of labels from each split of the MNIST stream.

    For each split, the stream is switched to it, ``stream.batch['label']`` is
    evaluated in ``model.sess`` with ``feed_dict=stream.givens(k)``, and the result
    is printed. The training split uses ``k=1`` and the validation split ``k=10``.
    NOTE(review): the meaning of the ``givens`` argument (likely a batch size) is
    defined by t4 and not visible here.

    A single-file alternative is
    ``t4.NpzDataStream('/data/plan/mnist/mnist-valid.npz', 'image', 'label')``.

    Returns:
        0 (presumably a process exit status -- confirm with the caller).
    """
    model = Model()
    stream = t4.MultiNpzDataStream(
        dict(train='/data/plan/mnist/mnist-train.npz',
             valid='/data/plan/mnist/mnist-valid.npz'),
        'image', 'label')
    model.setup(stream)

    with stream.using_workers():
        for split_name, givens_arg in (('train', 1), ('valid', 10)):
            stream.selected = split_name
            labels = model.sess.run(stream.batch['label'],
                                    feed_dict=stream.givens(givens_arg))
            print(labels)

    return 0
示例#4
0
文件: residual.py 项目: nvagus/tforce
def main():
    """Train on CIFAR-10 for 200 epochs with a step learning-rate schedule.

    The learning rate starts at 0.1 and is multiplied by 0.2 at the start of epochs
    60, 120 and 160, and is passed to training through ``givens={model.lr: ...}``.
    Each epoch:
    - prints a separator line;
    - runs ``size // Alice.default.batch_size`` training steps on the 'train' split
      (presumably about one pass over the data);
    - runs validation on the 'valid' split via ``Bob(model.valid, 100).run(100, highlight=True)``.

    Returns:
        0 (presumably a process exit status -- confirm with the caller).
    """
    model = Model()
    stream = t4.MultiNpzDataStream(
        dict(train='/data/plan/cifar-10/cifar-10-train.npz',
             valid='/data/plan/cifar-10/cifar-10-valid.npz'),
        'image', 'label')
    model.setup(stream)

    decay_epochs = (60, 120, 160)
    decay_factor = 0.2
    learning_rate = 0.1
    separator = '-' * 100

    with model.using_workers():
        for epoch in range(200):
            if epoch in decay_epochs:
                learning_rate *= decay_factor
            print(f'epoch {epoch} {separator}')

            stream.selected = 'train'
            trainer = t4.trainer.Alice(model.train)
            # Steps per epoch: training-set size divided by Alice's default batch size.
            steps = stream.data['train'].size // t4.trainer.Alice.default.batch_size
            trainer.run(steps, 1, givens={model.lr: learning_rate})

            stream.selected = 'valid'
            t4.trainer.Bob(model.valid, 100).run(100, highlight=True)

    return 0