예제 #1
0
def main():
    """Train and validate an MNIST model with the high-level ``Model`` API.

    All settings are read from the module-level ``FLAGS`` object
    (``device``, ``dynamic``, ``lr``, ``resume``, ``epoch``, ``batch_size``).
    Training runs through ``model.fit`` with ``save_dir='mnist_checkpoint'``.
    """
    device = set_device(FLAGS.device)
    # Enable dygraph (imperative) mode only when FLAGS.dynamic is set;
    # otherwise no mode switch is made here.
    if FLAGS.dynamic:
        fluid.enable_dygraph(device)

    train_dataset = MnistDataset(mode='train')
    val_dataset = MnistDataset(mode='test')

    # Flattened 784-value float32 image input and an int64 label input; the
    # leading ``None`` is presumably the (unspecified) batch dimension.
    inputs = [Input([None, 784], 'float32', name='image')]
    labels = [Input([None, 1], 'int64', name='label')]

    model = MNIST()
    optim = Momentum(learning_rate=FLAGS.lr,
                     momentum=.9,
                     parameter_list=model.parameters())

    # Accuracy is configured for top-1 and top-2 predictions.
    model.prepare(optim,
                  CrossEntropy(),
                  Accuracy(topk=(1, 2)),
                  inputs,
                  labels,
                  device=FLAGS.device)

    # Optionally restore state from a previous run (FLAGS.resume is
    # presumably a checkpoint path).
    if FLAGS.resume is not None:
        model.load(FLAGS.resume)

    model.fit(train_dataset,
              val_dataset,
              epochs=FLAGS.epoch,
              batch_size=FLAGS.batch_size,
              save_dir='mnist_checkpoint')
예제 #2
0
파일: mnist.py 프로젝트: zdqf/hapi
def main():
    """Train or evaluate a LeNet classifier on MNIST using ``paddle.Model``.

    Settings are read from the module-level ``FLAGS`` object (``static``,
    ``device``, ``lr``, ``resume``, ``eval_only``, ``epoch``, ``batch_size``,
    ``output_dir``). When ``FLAGS.eval_only`` is set, the model is only
    evaluated on the test split and the function returns without training.
    """
    # Static-graph mode is opt-in; without the flag no mode switch happens.
    if FLAGS.static:
        paddle.enable_static()
    # Called for its side effect (presumably selecting the execution device);
    # the returned value was never used, so it is not bound to a name.
    paddle.set_device(FLAGS.device)

    train_dataset = MNIST(mode='train')
    val_dataset = MNIST(mode='test')

    # Image input of shape [None, 1, 28, 28] (presumably N, C, H, W) and an
    # int64 label input.
    inputs = [Input(shape=[None, 1, 28, 28], dtype='float32', name='image')]
    labels = [Input(shape=[None, 1], dtype='int64', name='label')]

    net = LeNet()
    model = paddle.Model(net, inputs, labels)

    optim = Momentum(learning_rate=FLAGS.lr,
                     momentum=.9,
                     parameter_list=model.parameters())

    # Accuracy is configured for top-1 and top-2 predictions.
    model.prepare(optim, paddle.nn.CrossEntropyLoss(),
                  paddle.metric.Accuracy(topk=(1, 2)))

    # Optionally restore state from a previous run.
    if FLAGS.resume is not None:
        model.load(FLAGS.resume)

    if FLAGS.eval_only:
        model.evaluate(val_dataset, batch_size=FLAGS.batch_size)
        return

    model.fit(train_dataset,
              val_dataset,
              epochs=FLAGS.epoch,
              batch_size=FLAGS.batch_size,
              save_dir=FLAGS.output_dir)
예제 #3
0
def main():
    """Train and evaluate the MNIST model with a hand-written epoch loop.

    Runs inside ``fluid.dygraph.guard()`` when ``FLAGS.dynamic`` is set,
    otherwise inside a do-nothing guard. Reads ``FLAGS.batch_size``,
    ``FLAGS.num_devices``, ``FLAGS.lr``, ``FLAGS.resume`` and ``FLAGS.epoch``.
    Each train/eval step is dispatched with ``device='gpu'`` and device ids
    ``0 .. FLAGS.num_devices - 1``. After every epoch the model is saved to
    ``mnist_checkpoints/<epoch as two digits>``.
    """
    @contextlib.contextmanager
    def null_guard():
        # No-op stand-in used in place of fluid.dygraph.guard().
        yield

    def to_arrays(batch):
        # Convert a list of (image, label) samples into two arrays:
        # images reshaped to (-1, 1, 28, 28) and labels to (-1, 1).
        images = np.array([x[0] for x in batch]).reshape(-1, 1, 28, 28)
        targets = np.array([x[1] for x in batch]).reshape(-1, 1)
        return [images, targets]

    def run_epoch(step_fn, loader, device_ids):
        # Feed every batch from `loader` through `step_fn` (model.train or
        # model.eval), printing running mean loss/accuracy every 10 batches.
        total_loss = 0.0
        total_acc = 0.0
        for idx, batch in enumerate(loader()):
            outputs, losses = step_fn(batch[0],
                                      batch[1],
                                      device='gpu',
                                      device_ids=device_ids)

            acc = accuracy(outputs[0], batch[1])[0]
            total_loss += np.sum(losses)
            total_acc += acc
            if idx % 10 == 0:
                print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
                    idx, total_loss / (idx + 1), total_acc / (idx + 1)))

    guard = fluid.dygraph.guard() if FLAGS.dynamic else null_guard()

    checkpoint_dir = 'mnist_checkpoints'
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # Training data is shuffled through a 60000-sample buffer (presumably
    # sized to the full MNIST training set); both loaders drop the final
    # incomplete batch.
    train_loader = fluid.io.xmap_readers(
        to_arrays,
        paddle.batch(fluid.io.shuffle(paddle.dataset.mnist.train(), 60000),
                     batch_size=FLAGS.batch_size,
                     drop_last=True), 1, 1)
    val_loader = fluid.io.xmap_readers(
        to_arrays,
        paddle.batch(paddle.dataset.mnist.test(),
                     batch_size=FLAGS.batch_size,
                     drop_last=True), 1, 1)

    device_ids = list(range(FLAGS.num_devices))

    with guard:
        model = MNIST()
        optim = Momentum(learning_rate=FLAGS.lr,
                         momentum=.9,
                         parameter_list=model.parameters())
        model.prepare(optim, CrossEntropy())
        if FLAGS.resume is not None:
            model.load(FLAGS.resume)

        for e in range(FLAGS.epoch):
            print("======== train epoch {} ========".format(e))
            run_epoch(model.train, train_loader, device_ids)

            print("======== eval epoch {} ========".format(e))
            run_epoch(model.eval, val_loader, device_ids)

            model.save('{}/{:02d}'.format(checkpoint_dir, e))