コード例 #1
0
ファイル: test_symbol.py プロジェクト: nick6918/hello
def test_symbol_saveload():
    """Round-trip ``models.mlp2()`` through a JSON file and compare the graphs.

    The symbol is written with ``Symbol.save`` and read back with
    ``mx.symbol.load``; both ``tojson()`` serializations must be equal.
    The temporary file is removed even when the comparison fails.
    """
    sym = models.mlp2()
    fname = 'tmp_sym.json'
    try:
        sym.save(fname)
        data2 = mx.symbol.load(fname)
        # NOTE(review): the original comment read "save because of order";
        # presumably comparing tojson() strings is safe because the
        # serialization order is deterministic -- confirm.
        assert sym.tojson() == data2.tojson()
    finally:
        # Always clean up, but only if save() actually created the file, so an
        # exception from save() is not masked by a failing os.remove().
        if os.path.exists(fname):
            os.remove(fname)
コード例 #2
0
ファイル: test_infer_shape.py プロジェクト: antinucleon/mxnet
def test_mlp2_infer_shape():
    """Check shape inference on the MLP returned by ``models.mlp2``.

    With a (100, 100) ``data`` input, inference must produce exactly one
    output of shape (100, 10) and the fully connected parameter shapes
    listed in ``expected``.
    """
    net = models.mlp2()
    batch_input = (100, 100)
    inferred_args, inferred_outs, _aux = net.infer_shape(data=batch_input)
    shape_by_name = {name: shape
                     for name, shape in zip(net.list_arguments(), inferred_args)}

    assert len(inferred_outs) == 1
    assert inferred_outs[0] == (100, 10)

    expected = {'fc2_bias': (10,),
                'fc2_weight': (10, 1000),
                'fc1_bias': (1000,),
                'fc1_weight': (1000, 100)}
    for name in expected:
        assert shape_by_name[name] == expected[name]
コード例 #3
0
def test_symbol_basic():
    """Print the argument and output names of the symbol from ``models.mlp2``.

    ``models.mlp2`` builds an MLP with two fully connected layers (fc1, fc2).
    Expected values, per the original notes:

        list_arguments: ['data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias']
        list_outputs: ['fc2_output']

    The names are only printed, not asserted.
    """
    banner = "======================test_symbol_basic========================"
    print(banner)
    mlist = [models.mlp2()]
    for m in mlist:
        # Call each introspection method once and print its result; the
        # single-argument print(...) form behaves the same on Python 2 and 3.
        arguments = m.list_arguments()
        print('list_arguments')
        print(arguments)
        outputs = m.list_outputs()
        print('list_outputs')
        print(outputs)
    print(banner)
コード例 #4
0
ファイル: test_symbol.py プロジェクト: TangXing/mxnet
def test_symbol_basic():
    """Smoke-test the introspection API of the ``models.mlp2`` symbol.

    Only verifies that ``list_arguments`` and ``list_outputs`` can be called
    without raising; their return values are not inspected.
    """
    for symbol in [models.mlp2()]:
        symbol.list_arguments()
        symbol.list_outputs()
コード例 #5
0
ファイル: test_symbol.py プロジェクト: nick6918/hello
def test_symbol_pickle():
    """Pickle a list of symbols and check each unpickled symbol is unchanged.

    Symbols are compared through their ``tojson()`` serialization.
    """
    mlist = [models.mlp2(), models.conv()]
    data = pkl.dumps(mlist)
    # Unpickling is safe here: ``data`` was produced by this test, not read
    # from an external source.
    mlist2 = pkl.loads(data)
    # zip() stops at the shorter input, so check lengths first; otherwise a
    # truncated result would make the loop below pass vacuously.
    assert len(mlist2) == len(mlist)
    for x, y in zip(mlist, mlist2):
        assert x.tojson() == y.tojson()
コード例 #6
0
ファイル: test_infer_shape.py プロジェクト: antinucleon/mxnet
def test_mlp2_infer_error():
    """Shape inference must reject a parameter shape that contradicts the graph.

    With a (100, 100) ``data`` input, fc1 of ``models.mlp2`` needs a
    (1000, 100) ``fc1_weight``. Passing (1, 100) is inconsistent, so
    ``infer_shape`` is expected to raise rather than return shapes.
    """
    out = models.mlp2()
    weight_shape = (1, 100)
    data_shape = (100, 100)
    try:
        out.infer_shape(data=data_shape, fc1_weight=weight_shape)
    except mx.base.MXNetError:
        pass  # expected: the inconsistent fc1_weight shape is rejected
    else:
        raise AssertionError(
            "infer_shape accepted an inconsistent fc1_weight shape")
コード例 #7
0
def test_symbol_basic():
    """Check that the ``models.mlp2`` symbol exposes its argument/output lists.

    This is a smoke test: the lists are fetched and then discarded, so the
    test fails only if one of the calls raises.
    """
    symbols = [models.mlp2()]
    for sym in symbols:
        _ = sym.list_arguments()
        _ = sym.list_outputs()
コード例 #8
0
def test_symbol_pickle():
    """Pickle a list of symbols and check each unpickled symbol is unchanged.

    Symbols are compared through their ``tojson()`` serialization.
    """
    mlist = [models.mlp2(), models.conv()]
    data = pkl.dumps(mlist)
    # Unpickling is safe here: ``data`` was produced by this test, not read
    # from an external source.
    mlist2 = pkl.loads(data)
    # zip() stops at the shorter input, so check lengths first; otherwise a
    # truncated result would make the loop below pass vacuously.
    assert len(mlist2) == len(mlist)
    for x, y in zip(mlist, mlist2):
        assert x.tojson() == y.tojson()
コード例 #9
0
def test_mlp2_infer_error():
    """Shape inference must reject a parameter shape that contradicts the graph.

    With a (100, 100) ``data`` input, fc1 of ``models.mlp2`` needs a
    (1000, 100) ``fc1_weight``. Passing (1, 100) is inconsistent, so
    ``infer_shape`` is expected to raise rather than return shapes.
    """
    out = models.mlp2()
    weight_shape = (1, 100)
    data_shape = (100, 100)
    try:
        out.infer_shape(data=data_shape, fc1_weight=weight_shape)
    except mx.base.MXNetError:
        pass  # expected: the inconsistent fc1_weight shape is rejected
    else:
        raise AssertionError(
            "infer_shape accepted an inconsistent fc1_weight shape")