print("split") print((node._decision._dim.data(), node._decision._split.data())) left = next(key for key, value in children.items() if value == -1) right = next(key for key, value in children.items() if value == 1) print("left") traverse(left) print("right") traverse(right) traverse() # %% root = next(iter(tree._structure.items()))[0] router, router_mat, weight, embedd = tree._contextify( nd.array([[1, 1], [2, 2], [-1, -1]]))(root) nd.sum(router_mat, axis=-1) # %% root = next(iter(tree._structure.items()))[0] router_d, router_mat_d, weight_d, embedd_d = tree._contextify( nd.array([[1, 1], [2, 2], [-1, -1]]))(root) router = nd.stack(*[router_d[key] for key in sorted(router_d)], axis=-1) weight = nd.stack(*[weight_d[key] for key in sorted(weight_d)], axis=-1) embedd = nd.stack(*[embedd_d[key] for key in sorted(embedd_d)], axis=0) router_mat = nd.stack(*[router_mat_d[key] for key in sorted(router_mat_d)], axis=1)
# Prune redundant embedding nodes: if several children of the embedding layer
# share the same non-None `after` value, every child carrying the most common
# such value is pruned. `tree` and `after` come from earlier cells; `after` is
# zipped against the children below, so presumably it holds one entry per child
# in the same order -- TODO confirm.
from collections import Counter

# Count only real values; None entries never take part in the vote.
counts = Counter(x for x in after if x is not None)
# Empty when every entry is None (the old max() over an empty set raised
# ValueError). Ties resolve deterministically to the first value seen.
mode = counts.most_common(1)[0][0] if counts else None
# Only a value that occurs more than once marks nodes as redundant.
hitlist = mode if mode is not None and counts[mode] > 1 else None

# Snapshot the children once; `_prune` is called while walking them and may
# modify `_children` (not visible here), so avoid iterating the live container.
children = list(tree._embeddlayer._children.values())
for _, value in zip(children, after):
    print(value)

# With no repeated value there is nothing to prune. Without this guard every
# value was compared against hitlist=None, pruning all None-valued nodes.
if hitlist is not None:
    for node, value in zip(children, after):
        if value == hitlist:
            tree._prune(node)
# %%
# Evaluate the tree on the single input [[1.75]], starting from the first key of
# `tree._structure` (presumably the root node -- TODO confirm). `_contextify`
# returns four dicts keyed per node; each is stacked in sorted-key order.
root = next(iter(tree._structure.items()))[0]
router_d, router_mat_d, weight_d, embedd_d = tree._contextify(
    nd.array([[1.75]]))(root)


def _stack_by_key(parts, axis):
    """Stack the values of the dict ``parts`` along ``axis`` in sorted-key order."""
    return nd.stack(*[parts[key] for key in sorted(parts)], axis=axis)


router = _stack_by_key(router_d, axis=-1)
weight = _stack_by_key(weight_d, axis=-1)
embedd = _stack_by_key(embedd_d, axis=0)
router_mat = _stack_by_key(router_mat_d, axis=1)

# For each row of `router`, pick the index along axis 1 whose value is closest
# to -0.5 (axis 1 is presumably the node axis created by the stack above --
# TODO confirm), then gather router_mat[row][index] and concatenate the results.
where = nd.argmin(nd.abs(router + 0.5), axis=1)
head = nd.concat(*[router_mat[i][k] for i, k in enumerate(where)], dim=0)
# %%