# Interactive scratch cells (delimited by ``# %%``) that build a ``Tree``,
# initialise its parameters, prune nodes and inspect its internal layers.
import mxnet as mx
import numpy as np
from mxnet import gluon, nd

from Tree import Tree

# %%
# Build a fresh tree and (re)initialise the parameters it exposes.
# ``force_reinit=True`` re-initialises parameters that already hold values
# (Gluon ``initialize`` semantics -- assumes ``Tree`` is a Gluon block).
tree = Tree()
tree.collect_params().initialize(force_reinit=True)
tree.collect_params()  # bare expression: displayed by the interactive runner

# %%
# First key of ``tree._structure``; presumably the root node -- TODO confirm.
root = next(iter(tree._structure.items()))[0]
tree._prune(root)

# ``node`` is bound here, before its first use.  The original assigned it only
# after ``tree._prune(node)`` / ``tree._structure[node]``, which raises
# NameError when this cell is executed top to bottom.
# NOTE(review): this lookup now runs after ``tree._prune(root)``; confirm that
# child '1' still exists at this point.
node = tree._embeddlayer._children['1']
tree._prune(node)

# Inspection-only expressions: each value is shown by the interactive runner
# and otherwise discarded.
tree.collect_params()
tree._structure[node]
next(iter(tree._structure.items()))[0]._box._parent
tree._weightlayer
tree._routerlayer._children.keys()
tree._weightlayer._children.keys()
tree._embeddlayer._children.keys()

# %%
if hasattr(node, "_decision") else None for node in tree._embeddlayer._children.values() ]
# NOTE(review): the line above is the tail of a list comprehension whose
# opening is not present in the visible source; it is kept verbatim.  The
# code below reads the resulting list as ``after``: one entry per child of
# ``tree._embeddlayer``, ``None`` for children without a ``_decision``
# attribute.
# NOTE(review): the original indentation of this cell was lost.  The ``for``
# loop is assumed to sit inside the ``if`` (``hit_value`` is only bound
# there) and both ``print`` calls are assumed to be top-level -- confirm.

size = len(after)  # not read again in the visible code
if (len(tree._embeddlayer) > 1):
    # Most frequent non-None entry of ``after``.
    # NOTE(review): ``max`` raises ValueError on an empty set, i.e. when
    # every entry of ``after`` is None.
    mode = max(set([x for x in after if x is not None]), key=after.count)
    after.count(mode)  # bare expression; the result is discarded
    # Only keep the mode as the pruning target if it occurs more than once.
    hit_value = mode if after.count(mode) > 1 else None
    # hit_value = max(set([x for x in after if x is not None]), key = after.count)
    # Iterate over a list snapshot of the children, presumably because
    # ``tree._prune`` mutates ``_children`` -- TODO confirm.
    for node, value in zip(list(tree._embeddlayer._children.values()), after):
        # Prune children whose entry equals the repeated mode or equals 0.
        # NOTE(review): when ``hit_value`` is None, entries that are None
        # also satisfy ``value == hit_value`` -- confirm this is intended.
        if (value == hit_value or value == 0):
            tree._prune(node)
print(len(tree._routerlayer))
print("done")


# %%
# NOTE(review): the default for ``node`` is evaluated once, when this ``def``
# statement runs, so it refers to whatever ``tree._structure`` held at that
# moment.  The function body may continue past the last visible line.
def traverse(node=next(iter(tree._structure.items()))[0]):
    print("box")
    # Print the (min, max) data of the node's box; a side is printed as None
    # when the corresponding ``shape`` attribute is None.
    print((node._box._min_list.data() if node._box._min_list.shape is not None else None,
           node._box._max_list.data() if node._box._max_list.shape is not None else None))
    # Look up the entry stored for this node in the tree structure mapping.
    children = tree._structure[node]