Ejemplo n.º 1
0
def cifar10_iterator(batch_size,
                     data_shape,
                     resize=-1,
                     num_parts=1,
                     part_index=0):
    """Build the CIFAR-10 training and validation record iterators.

    Both iterators read from ``data/cifar/`` and are sharded with the same
    ``num_parts`` / ``part_index``. No mean image is configured for either.

    Args:
        batch_size: Batch size forwarded to both iterators.
        data_shape: Data shape forwarded to both iterators.
        resize: Forwarded unchanged as the iterators' ``resize`` option.
        num_parts: Number of partitions the data is split into.
        part_index: Index of the partition this caller reads.

    Returns:
        A ``(train, val)`` tuple of ``mx.io.ImageRecordIter`` objects.
    """
    # Presumably makes sure the record files exist under data/cifar/ --
    # the helper is defined outside this block; confirm there.
    get_data.GetCifar10()

    # Options that are identical for the training and validation iterators.
    shared = dict(
        resize=resize,
        data_shape=data_shape,
        batch_size=batch_size,
        num_parts=num_parts,
        part_index=part_index,
    )

    # Random crop/mirror augmentation is enabled for training only.
    train = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/train.rec",
        rand_crop=True,
        rand_mirror=True,
        **shared)

    val = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/test.rec",
        rand_crop=False,
        rand_mirror=False,
        **shared)

    return train, val
Ejemplo n.º 2
0
def cifar10(batch_size, input_shape, num_parts=1, part_index=0):
    """Create deterministic CIFAR-10 train and validation iterators.

    Both iterators subtract ``data/cifar/cifar_mean.bin`` and have random
    cropping, mirroring, shuffling and batch rounding switched off. Only the
    training iterator is sharded by ``num_parts`` / ``part_index``; the
    validation iterator is created without those options.

    Args:
        batch_size: Batch size for both iterators.
        input_shape: Data shape for both iterators.
        num_parts: Number of partitions of the training data.
        part_index: Index of the training partition to read.

    Returns:
        A ``(train, val)`` tuple of ``mx.io.ImageRecordIter`` objects.
    """
    # Presumably makes sure the record files exist under data/cifar/ --
    # the helper is defined outside this block; confirm there.
    get_data.GetCifar10()

    # Every source of randomness / padding is disabled for both iterators.
    deterministic = dict(
        rand_crop=False,
        rand_mirror=False,
        shuffle=False,
        round_batch=False,
    )
    common = dict(
        mean_img="data/cifar/cifar_mean.bin",
        data_shape=input_shape,
        batch_size=batch_size,
    )

    train = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/train.rec",
        num_parts=num_parts,
        part_index=part_index,
        **common,
        **deterministic)
    # The validation iterator is intentionally built without partitioning
    # arguments, exactly as the sharding options are omitted here.
    val = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/test.rec",
        **common,
        **deterministic)
    return train, val
Ejemplo n.º 3
0
def data(data_dir, batch_size, num_parts=1, part_index=0):
    """Return the CIFAR-10 training and validation iterators.

    File paths are built with ``os.path.join`` so ``data_dir`` works with or
    without a trailing separator. When ``data_dir`` refers to the default
    ``data/cifar`` directory, ``get_data.GetCifar10()`` is called first.

    Args:
        data_dir: Directory containing ``train.rec``, ``test.rec`` and
            ``cifar_mean.bin``.
        batch_size: Batch size for both iterators.
        num_parts: Number of partitions of the training data.
        part_index: Index of the training partition to read.

    Returns:
        A ``(train, val)`` tuple of ``mx.io.ImageRecordIter`` objects. Only
        the training iterator is sharded and augmented (random crop/mirror).
    """
    import os

    # Compare normalized paths so "data/cifar" and "data/cifar/" are treated
    # the same way instead of relying on an exact string match.
    if os.path.normpath(data_dir) == os.path.normpath("data/cifar/"):
        common_dir = "../../tests/python/common"
        # Avoid growing sys.path with a duplicate entry on every call.
        if common_dir not in sys.path:
            sys.path.insert(0, common_dir)
        import get_data
        get_data.GetCifar10()

    mean_img = os.path.join(data_dir, "cifar_mean.bin")
    input_shape = (3, 28, 28)
    train = mx.io.ImageRecordIter(path_imgrec=os.path.join(data_dir, "train.rec"),
                                  mean_img=mean_img,
                                  data_shape=input_shape,
                                  batch_size=batch_size,
                                  rand_crop=True,
                                  rand_mirror=True,
                                  num_parts=num_parts,
                                  part_index=part_index)
    val = mx.io.ImageRecordIter(path_imgrec=os.path.join(data_dir, "test.rec"),
                                mean_img=mean_img,
                                rand_crop=False,
                                rand_mirror=False,
                                data_shape=input_shape,
                                batch_size=batch_size)
    return (train, val)
Ejemplo n.º 4
0
# Tail of the network definition. `in4b`, `SimpleFactory` and
# `DownsampleFactory` are defined outside this excerpt; each layer below takes
# the previous symbol as its input.
in4c = SimpleFactory(in4b, 80, 80)
in4d = SimpleFactory(in4c, 48, 96)
in4e = DownsampleFactory(in4d, 96)
in5a = SimpleFactory(in4e, 176, 160)
in5b = SimpleFactory(in5a, 176, 160)
# Average pooling with a 7x7 kernel.
# NOTE(review): the name suggests this covers the whole feature map -- confirm
# the spatial size of `in5b` is 7x7.
pool = mx.symbol.Pooling(data=in5b,
                         pool_type="avg",
                         kernel=(7, 7),
                         name="global_pool")
# Classifier head: flatten -> fully connected (10 outputs, presumably one per
# CIFAR-10 class) -> softmax output named "loss".
flatten = mx.symbol.Flatten(data=pool, name="flatten1")
fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
softmax = mx.symbol.SoftmaxOutput(data=fc, name="loss")

#########################################################

# Presumably makes sure the CIFAR-10 record files exist under data/cifar/ --
# `get_data` is defined outside this excerpt; confirm there.
get_data.GetCifar10()
# Training configuration constants.
batch_size = 128
num_epoch = 10
num_gpus = 1

# Training iterator: subtracts the stored mean image and applies random
# crop/mirror augmentation on (3, 28, 28) inputs, using a single
# preprocessing thread.
# NOTE(review): 'loss_label' presumably has to match the label input of the
# SoftmaxOutput symbol named "loss" -- confirm.
train_dataiter = mx.io.ImageRecordIter(path_imgrec="data/cifar/train.rec",
                                       mean_img="data/cifar/cifar_mean.bin",
                                       rand_crop=True,
                                       rand_mirror=True,
                                       data_shape=(3, 28, 28),
                                       batch_size=batch_size,
                                       preprocess_threads=1,
                                       label_name='loss_label')
test_dataiter = mx.io.ImageRecordIter(path_imgrec="data/cifar/test.rec",
                                      mean_img="data/cifar/cifar_mean.bin",
                                      rand_crop=False,