def test_MNISTIter():
    """Check MNISTIter's batch count and that reset() replays the same order.

    Iterates the whole MNIST training iterator once and asserts the number
    of batches equals ``60000 // batch_size``. Then it reads the first
    batch's labels after a reset, advances a few batches, and resets again.
    The first batch's labels must match exactly; the iterator is built with
    ``shuffle=1`` and ``seed=10``.
    """
    # prepare data
    get_data.GetMNIST_ubyte()

    batch_size = 100
    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        data_shape=(784, ),
        batch_size=batch_size,
        shuffle=1,
        flat=1,
        silent=0,
        seed=10,
    )

    # test_loop
    # Integer division: under Python 3 `/` would produce a float (600.0).
    nbatch = 60000 // batch_size
    batch_count = sum(1 for _ in train_dataiter)
    assert nbatch == batch_count

    # test_reset
    train_dataiter.reset()
    train_dataiter.iter_next()
    label_0 = train_dataiter.getlabel().asnumpy().flatten()
    # Move past the first batch so the second reset() has real work to undo.
    for _ in range(4):
        train_dataiter.iter_next()
    train_dataiter.reset()
    train_dataiter.iter_next()
    label_1 = train_dataiter.getlabel().asnumpy().flatten()
    # Compare element-wise. A zero sum of differences can hide mismatches
    # that cancel each other out (e.g. +1 and -1).
    assert label_0.shape == label_1.shape
    assert (label_0 == label_1).all()
def get_iters(batch_size=100):
    """Build the MNIST training and validation iterators.

    Both iterators use ``data_shape=(1, 28, 28)``, ``label_name='sm_label'``,
    ``shuffle=True``, ``flat=False`` and ``silent=False``. Only the training
    iterator is given an explicit ``seed`` (10).

    Args:
        batch_size: Number of samples per batch for both iterators. The
            default of 100 is the value that used to be hard-coded.

    Returns:
        A ``(train_dataiter, val_dataiter)`` tuple of ``mx.io.MNISTIter``.
    """
    # check data
    get_data.GetMNIST_ubyte()

    # Options shared by both iterators, kept in one place so they cannot drift.
    common = dict(
        data_shape=(1, 28, 28),
        label_name='sm_label',
        batch_size=batch_size,
        shuffle=True,
        flat=False,
        silent=False,
    )
    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        seed=10,
        **common
    )
    # NOTE(review): no seed here, so the validation shuffle order is
    # presumably not reproducible across runs. Confirm this is intended.
    val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        **common
    )
    return train_dataiter, val_dataiter
stride=(2, 2), pool_type='max') fl = mx.symbol.Flatten(data=mp2, name="flatten") fc2 = mx.symbol.FullyConnected(data=fl, name='fc2', num_hidden=10) softmax = mx.symbol.SoftmaxOutput(data=fc2, name='sm') num_epoch = 1 model = mx.model.FeedForward(softmax, mx.cpu(), num_epoch=num_epoch, learning_rate=0.1, wd=0.0001, momentum=0.9) # check data get_data.GetMNIST_ubyte() train_dataiter = mx.io.MNISTIter(image="data/train-images-idx3-ubyte", label="data/train-labels-idx1-ubyte", data_shape=(1, 28, 28), label_name='sm_label', batch_size=batch_size, shuffle=True, flat=False, silent=False, seed=10) val_dataiter = mx.io.MNISTIter(image="data/t10k-images-idx3-ubyte", label="data/t10k-labels-idx1-ubyte", data_shape=(1, 28, 28), label_name='sm_label', batch_size=batch_size,