コード例 #1
0
ファイル: train.py プロジェクト: arlose/cv
# Train a LeNet-style network (symbol from lenet.getsymbol()) with MXNet's
# FeedForward API on RecordIO image data, evaluating on a held-out set.
batch_size = 64
# (channels, height, width) of each sample fed to the network.
data_shape = (1, 28, 28)

# Training iterator: random crop / mirror augmentation is applied here only.
train = mx.io.ImageRecordIter(
        path_imgrec = "train1.rec",
        mean_img    = "mean.bin",
        data_shape  = data_shape,
        batch_size  = batch_size,
        rand_crop   = True,
        rand_mirror = True)
# Validation iterator: augmentation is deliberately disabled so the reported
# accuracy is deterministic and measured on the unmodified images.
val = mx.io.ImageRecordIter(
        path_imgrec = "val1.rec",
        mean_img    = "mean.bin",
        data_shape  = data_shape,
        batch_size  = batch_size,
        rand_crop   = False,
        rand_mirror = False)

softmax = lenet.getsymbol()

num_round = 30  # number of training epochs (passed as num_epoch)
num_gpu = 1
gpus = [mx.gpu(i) for i in range(num_gpu)]

model = mx.model.FeedForward(ctx=gpus, symbol=softmax, num_epoch=num_round,
                             learning_rate=0.001, momentum=0.9, wd=0.00001)
model_prefix = "lenet"
# Speedometer(batch_size, 200) presumably logs throughput every 200 batches;
# do_checkpoint(model_prefix) is hooked in as the end-of-epoch callback to
# save checkpoints under the "lenet" prefix (see MXNet callback docs).
model.fit(X=train, eval_data=val,
          eval_metric="accuracy",
          batch_end_callback=mx.callback.Speedometer(batch_size, 200),
          epoch_end_callback=mx.callback.do_checkpoint(model_prefix))
コード例 #2
0
ファイル: transfer.py プロジェクト: arlose/cv
        mean_img    = "mean.bin",
        data_shape  = data_shape,
        batch_size  = batch_size,
        rand_crop   = True,
        rand_mirror = True)

# Validation iterator: augmentation is deliberately disabled so evaluation is
# deterministic and reflects the unmodified images.
val_data = mx.io.ImageRecordIter(
        path_imgrec = "val2_11.rec",
        mean_img    = "mean.bin",
        data_shape  = data_shape,
        batch_size  = batch_size,
        rand_crop   = False,
        rand_mirror = False)

# Presumably the number of output classes of the new network (the record file
# name "val2_11.rec" suggests 11 classes) -- TODO confirm against lenet.getsymbol.
number = 11
softmax = lenet.getsymbol(number)

# Infer argument shapes for a single sample whose (channels, height, width)
# match the iterators' data_shape. The channel count is taken from data_shape
# instead of being hardcoded to 3, so e.g. the first convolution's weight shape
# agrees with the data the iterators actually produce.
arg_shapes, output_shapes, aux_shapes = softmax.infer_shape(
    data=(1, data_shape[0], data_shape[1], data_shape[2]))
arg_names = softmax.list_arguments()
# One zero-filled NDArray per network argument, keyed by argument name.
arg_dict = dict(zip(arg_names, [mx.nd.zeros(shape, ctx=ctx) for shape in arg_shapes]))

# load pretrained
# Checkpoint "lenet_10" at epoch 10; arg_params / aux_params are presumably
# name -> NDArray dicts of the pretrained weights used below.
model_prefix = "lenet_10"
epoch_load = 10
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch_load)

# init with pretrained weight
fixed_param_prefix = ['convolution0_bias', 'convolution1_bias', 'convolution2_bias','convolution0_weight', 'convolution1_weight', 'convolution2_weight']
for name in fixed_param_prefix:
    key = name
    if key in arg_params: