def test_add_graph_symbol():
    """Check graph export of a bare MXNet Convolution symbol.

    Builds ``Convolution(data, kernel=(2, 2), num_filter=2)`` and verifies:

    1. ``_get_nodes_from_symbol`` yields exactly the expected ``NodeDef`` list
       (the data input, the implicit weight/bias variables, and the op node);
    2. ``_net2pb`` wraps those nodes in the expected ``GraphDef``;
    3. ``SummaryWriter.add_graph`` accepts the symbol, after which
       ``check_event_file_and_remove_logdir`` is run (presumably it checks the
       event file on disk and cleans up ``_LOGDIR`` -- see its definition).
    """
    data = mx.sym.Variable('data')
    conv = mx.sym.Convolution(data, kernel=(2, 2), num_filter=2)
    # Materialize so the length check below works for any iterable return.
    nodes = list(_get_nodes_from_symbol(conv))

    # The operator's params are attached to the op node and to each of its
    # auto-created parameter variables; build the bytes once.
    conv_param = '{ kernel : (2, 2) , num_filter : 2 }'.encode(encoding='utf-8')
    expected_nodes = [
        NodeDef(name='data', op='null'),
        NodeDef(name='convolution0/convolution0_weight', op='null',
                attr={'param': AttrValue(s=conv_param)}),
        NodeDef(name='convolution0/convolution0_bias', op='null',
                attr={'param': AttrValue(s=conv_param)}),
        NodeDef(name='convolution0/convolution0', op='Convolution',
                input=['data', 'convolution0/convolution0_weight',
                       'convolution0/convolution0_bias'],
                attr={'param': AttrValue(s=conv_param)}),
    ]

    # check _get_nodes_from_symbol
    # zip() stops at the shorter sequence, so compare lengths first; otherwise
    # extra or missing nodes would slip through this loop unnoticed.
    assert len(nodes) == len(expected_nodes), \
        'expected %d nodes, got %d' % (len(expected_nodes), len(nodes))
    for i, (expected_node, node) in enumerate(zip(expected_nodes, nodes)):
        assert expected_node == node, 'node %d differs:\n%s\nvs\n%s' % (i, expected_node, node)

    # check _net2pb
    expected_graph = GraphDef(node=expected_nodes, versions=VersionDef(producer=100))
    graph = _net2pb(conv)
    assert expected_graph == graph

    # check add_graph; the writer is closed before the event file is checked
    with SummaryWriter(logdir=_LOGDIR) as sw:
        sw.add_graph(conv)
    check_event_file_and_remove_logdir()
def test_add_graph_gluon():
    """Check graph export of a hybridized Gluon network.

    Builds a ``HybridSequential`` with a single ``Dense(128, activation='relu')``
    layer, runs one forward pass so the cached graph exists, and verifies:

    1. ``_get_nodes_from_symbol`` on the cached symbol yields exactly the
       expected ``NodeDef`` list (data, weight, bias, FullyConnected, Activation);
    2. ``_net2pb`` on the block itself produces the expected ``GraphDef``;
    3. ``SummaryWriter.add_graph`` accepts the block, after which
       ``check_event_file_and_remove_logdir`` is run (presumably it checks the
       event file on disk and cleans up ``_LOGDIR`` -- see its definition).
    """
    net = nn.HybridSequential()
    with net.name_scope():
        net.add(nn.Dense(128, activation='relu'))
    net.hybridize()
    net.initialize()
    # A forward pass is needed before ``_cached_graph`` is populated.
    net.forward(mx.nd.ones(1,))
    _, sym = net._cached_graph
    # Materialize so the length check below works for any iterable return.
    nodes = list(_get_nodes_from_symbol(sym))

    expected_nodes = [
        NodeDef(name='data', op='null'),
        # NOTE(review): the recorded weight __shape__ is (128, 0); the 0 input
        # dim appears to be what the symbol attrs carry, not the runtime shape.
        NodeDef(name='hybridsequential0_dense0_fwd/hybridsequential0_dense0_weight', op='null',
                attr={'param': AttrValue(
                    s='{ __dtype__ : 0 , __lr_mult__ : 1.0 , __shape__ : '
                      '(128, 0) , __wd_mult__ : 1.0 }'.encode(encoding='utf-8'))}),
        NodeDef(name='hybridsequential0_dense0_fwd/hybridsequential0_dense0_bias', op='null',
                attr={'param': AttrValue(
                    s='{ __dtype__ : 0 , __init__ : zeros , __lr_mult__ : 1.0 , __shape__ : '
                      '(128,) , __wd_mult__ : 1.0 }'.encode(encoding='utf-8'))}),
        NodeDef(name='hybridsequential0_dense0_fwd/hybridsequential0_dense0_fwd',
                op='FullyConnected',
                input=['data',
                       'hybridsequential0_dense0_fwd/hybridsequential0_dense0_weight',
                       'hybridsequential0_dense0_fwd/hybridsequential0_dense0_bias'],
                attr={'param': AttrValue(
                    s='{ flatten : True , no_bias : False , '
                      'num_hidden : 128 }'.encode(encoding='utf-8'))}),
        NodeDef(name='hybridsequential0_dense0_relu_fwd/hybridsequential0_dense0_relu_fwd',
                op='Activation',
                input=['hybridsequential0_dense0_fwd/hybridsequential0_dense0_fwd'],
                attr={'param': AttrValue(s='{ act_type : relu }'.encode(encoding='utf-8'))}),
    ]

    # check _get_nodes_from_symbol
    # zip() stops at the shorter sequence, so compare lengths first; otherwise
    # extra or missing nodes would slip through this loop unnoticed.
    assert len(nodes) == len(expected_nodes), \
        'expected %d nodes, got %d' % (len(expected_nodes), len(nodes))
    for i, (expected_node, node) in enumerate(zip(expected_nodes, nodes)):
        assert expected_node == node, 'node %d differs:\n%s\nvs\n%s' % (i, expected_node, node)

    # check _net2pb
    expected_graph = GraphDef(node=expected_nodes, versions=VersionDef(producer=100))
    graph = _net2pb(net)
    assert expected_graph == graph

    # check add_graph; the writer is closed before the event file is checked
    with SummaryWriter(logdir=_LOGDIR) as sw:
        sw.add_graph(net)
    check_event_file_and_remove_logdir()