예제 #1
0
def convert_net(sym, args, ctx=None):
    """Bind ``sym`` with pretrained parameters and save it as a checkpoint.

    Parameters are loaded from ``args.params`` and shape-checked against
    ``sym`` for a single square image of side ``args.img_long_side`` plus a
    ``(1, 3)`` ``im_info`` input. They are then bound into an inference-only
    ``Module`` and written out with ``Module.save_checkpoint`` under
    ``args.save_prefix`` at epoch 0. No forward pass is run.

    Args:
        sym: Network symbol to bind. It is bound with the input names
            ``data`` and ``im_info``.
        args: Namespace providing ``params`` (parameter file passed to
            ``load_param``), ``img_long_side`` (side length used for the
            square ``data`` shape) and ``save_prefix`` (checkpoint prefix).
        ctx: Device context to load the parameters onto and bind the module
            with. Defaults to ``mx.cpu(0)``, the previously hard-coded value.
    """
    if ctx is None:
        ctx = mx.cpu(0)

    arg_params, aux_params = load_param(args.params, ctx=ctx)

    # Largest expected input: one square image at the maximum long side,
    # plus its (1, 3) im_info row.
    data_names = ['data', 'im_info']
    data_shapes = [('data', (1, 3, args.img_long_side, args.img_long_side)),
                   ('im_info', (1, 3))]

    check_shape(sym, data_shapes, arg_params, aux_params)

    # Inference-only module: no label inputs are bound.
    mod = Module(sym, data_names, label_names=None, context=ctx)
    mod.bind(data_shapes, label_shapes=None, for_training=False)
    mod.init_params(arg_params=arg_params, aux_params=aux_params)

    # Save the bound parameters directly; no forward pass is performed.
    mod.save_checkpoint(args.save_prefix, epoch=0)


def dummy_data(ctx, batch_size=1, *, height=600, width=600):
    """Return two random placeholder NDArrays allocated on ``ctx``.

    The first array has shape ``(batch_size, 3, height, width)``. The second
    has shape ``(batch_size,)`` and is presumably a label placeholder; the
    visible caller discards it. Both are filled by
    ``mx.nd.random.uniform`` with its default bounds.

    Args:
        ctx: Device context to allocate the arrays on.
        batch_size: Leading dimension of both arrays.
        height: Image height of the first array. Defaults to 600, the
            previously hard-coded value.
        width: Image width of the first array. Defaults to 600, the
            previously hard-coded value.

    Returns:
        A two-element list ``[data, second]`` of NDArrays.
    """
    shapes = ([batch_size, 3, height, width], [batch_size])
    return [mx.nd.random.uniform(shape=shape, ctx=ctx) for shape in shapes]


# Top-level conversion script: builds a symbol from VGGConvBlock, binds it
# to parameters loaded from disk and re-saves them as checkpoint 'test_vgg'.
# NOTE(review): all of this runs at import time; consider moving it under an
# `if __name__ == "__main__":` guard.

# Single input named 'data'. No labels are bound because the module is only
# bound for inference (for_training=False below).
data_names = ['data']
label_names = None
# NOTE(review): the module is bound for (1, 3, 1000, 600), but dummy_data()
# produces 600x600 images by default — confirm which shape is intended.
data_shapes = [('data', (1, 3, 1000, 600))]
label_shapes = None

# Build the graph symbolically by calling the hybridized block on a symbol
# placeholder; conv_feat is then used as the Module's symbol.
data = mx.symbol.Variable(name="data")
# NOTE(review): isBin/step are passed straight to VGGConvBlock, which is
# defined elsewhere — their meaning is not visible here.
GLUON_LAYER = VGGConvBlock(isBin=True, step=4)
GLUON_LAYER.hybridize()
conv_feat = GLUON_LAYER(data)

# NOTE(review): hard-coded, machine-specific absolute path; consider making it
# a command-line argument.
arg_params, aux_params = load_param(
    "/home/skutukov/work/mxnet_fasterrcnn_binary/convert/temp-0000.params",
    ctx=mx.cpu())
check_shape(conv_feat, data_shapes, arg_params, aux_params)

# Bind an inference-only module on CPU and load the checked parameters.
mod = Module(conv_feat, data_names, label_names, context=mx.cpu())
mod.bind(data_shapes, label_shapes, for_training=False)
mod.init_params(arg_params=arg_params, aux_params=aux_params)

# data1 is currently unused because the forward call below is commented out.
# NOTE(review): Module.forward takes a DataBatch, not a bare NDArray — wrap
# data1 (and match the bound shape) before re-enabling the call.
data1, _ = dummy_data(ctx=mx.cpu())
# mod.forward(data1)
# Write the bound parameters out as checkpoint 'test_vgg' at epoch 0.
mod.save_checkpoint('test_vgg', epoch=0)