Ejemplo n.º 1
0
def test_add_graph_symbol():
    """Check graph conversion of a bare MXNet ``Convolution`` symbol.

    Builds ``Convolution(data, kernel=(2, 2), num_filter=2)`` and verifies, in
    order, that ``_get_nodes_from_symbol`` yields the expected ``NodeDef`` list,
    that ``_net2pb`` wraps exactly those nodes in a ``GraphDef`` with
    ``producer=100``, and that ``SummaryWriter.add_graph`` accepts the symbol.
    """
    data = mx.sym.Variable('data')
    conv = mx.sym.Convolution(data, kernel=(2, 2), num_filter=2)
    # list() so the length comparison below works for any iterable return type.
    nodes = list(_get_nodes_from_symbol(conv))
    # Every node derived from the convolution carries the same serialized
    # operator parameters, so build the expected byte string once.
    conv_param = '{ kernel :  (2, 2) ,  num_filter :  2 }'.encode(encoding='utf-8')
    expected_nodes = [NodeDef(name='data', op='null'),
                      NodeDef(name='convolution0/convolution0_weight', op='null',
                              attr={'param': AttrValue(s=conv_param)}),
                      NodeDef(name='convolution0/convolution0_bias', op='null',
                              attr={'param': AttrValue(s=conv_param)}),
                      NodeDef(name='convolution0/convolution0', op='Convolution',
                              input=['data', 'convolution0/convolution0_weight', 'convolution0/convolution0_bias'],
                              attr={'param': AttrValue(s=conv_param)})]
    # check _get_nodes_from_symbol. zip() stops at the shorter sequence, so the
    # lengths must be compared first or a missing/extra node would go unnoticed.
    assert len(nodes) == len(expected_nodes)
    for expected_node, node in zip(expected_nodes, nodes):
        assert expected_node == node

    # check _net2pb
    expected_graph = GraphDef(node=expected_nodes, versions=VersionDef(producer=100))
    graph = _net2pb(conv)
    assert expected_graph == graph

    # check add_graph; the helper (defined elsewhere) presumably verifies that
    # an event file was written and then removes _LOGDIR.
    with SummaryWriter(logdir=_LOGDIR) as sw:
        sw.add_graph(conv)
    check_event_file_and_remove_logdir()
Ejemplo n.º 2
0
def test_add_graph_gluon():
    """Check graph conversion of a hybridized Gluon network.

    Builds a ``HybridSequential`` holding one ``Dense(128, activation='relu')``
    layer, then verifies, in order:

    * ``_get_nodes_from_symbol`` on the cached symbol yields the expected
      ``NodeDef`` list (input, weight, bias, FullyConnected, Activation);
    * ``_net2pb`` on the network wraps exactly those nodes in a ``GraphDef``
      with ``producer=100``;
    * ``SummaryWriter.add_graph`` accepts the network.
    """
    net = nn.HybridSequential()
    with net.name_scope():
        net.add(nn.Dense(128, activation='relu'))

    net.hybridize()
    net.initialize()
    # Run one forward pass after hybridize() so that net._cached_graph, read on
    # the next line, is populated.
    net.forward(mx.nd.ones(1,))
    _, sym = net._cached_graph
    # list() so the length comparison below works for any iterable return type.
    nodes = list(_get_nodes_from_symbol(sym))
    expected_nodes = [NodeDef(name='data', op='null'),
                      NodeDef(name='hybridsequential0_dense0_fwd/hybridsequential0_dense0_weight', op='null',
                              attr={'param': AttrValue(
                                  s='{ __dtype__ :  0 ,  __lr_mult__ :  1.0 ,  __shape__ :  '
                                    '(128, 0) ,  __wd_mult__ :  1.0 }'.encode(encoding='utf-8'))}),
                      NodeDef(name='hybridsequential0_dense0_fwd/hybridsequential0_dense0_bias', op='null',
                              attr={'param': AttrValue(
                                  s='{ __dtype__ :  0 ,  __init__ :  zeros ,  __lr_mult__ :  1.0 ,  __shape__ :  '
                                    '(128,) ,  __wd_mult__ :  1.0 }'.encode(encoding='utf-8'))}),
                      NodeDef(name='hybridsequential0_dense0_fwd/hybridsequential0_dense0_fwd', op='FullyConnected',
                              input=['data', 'hybridsequential0_dense0_fwd/hybridsequential0_dense0_weight',
                                     'hybridsequential0_dense0_fwd/hybridsequential0_dense0_bias'],
                              attr={'param': AttrValue(
                                  s='{ flatten :  True ,  no_bias :  False ,  '
                                    'num_hidden :  128 }'.encode(encoding='utf-8'))}),
                      NodeDef(name='hybridsequential0_dense0_relu_fwd/hybridsequential0_dense0_relu_fwd',
                              op='Activation', input=['hybridsequential0_dense0_fwd/hybridsequential0_dense0_fwd'],
                              attr={'param': AttrValue(
                                  s='{ act_type :  relu }'.encode(encoding='utf-8'))})
                      ]
    # check _get_nodes_from_symbol. zip() stops at the shorter sequence, so the
    # lengths must be compared first or a missing/extra node would go unnoticed.
    assert len(nodes) == len(expected_nodes)
    for expected_node, node in zip(expected_nodes, nodes):
        assert expected_node == node

    # check _net2pb (given the network itself, not the extracted symbol)
    expected_graph = GraphDef(node=expected_nodes, versions=VersionDef(producer=100))
    graph = _net2pb(net)
    assert expected_graph == graph

    # check add_graph; the helper (defined elsewhere) presumably verifies that
    # an event file was written and then removes _LOGDIR.
    with SummaryWriter(logdir=_LOGDIR) as sw:
        sw.add_graph(net)
    check_event_file_and_remove_logdir()