def inference_graph(self, data):
    """Build the hard-routing inference graph for this layer's tree.

    `data` is routed through the tree described by `self.tree_parameters`
    and `self.tree_thresholds` using `gen_training_ops.hard_routing_function`.
    The returned (path_probability, path) pair is expanded with
    `gen_training_ops.unpack_path`, and a window of
    `self.params.num_leaves` columns is sliced out of the expanded result.
    All ops are created under `ops.device(self.device_assigner)`.

    Args:
      data: Input handed unchanged to `hard_routing_function`. Presumably a
        batch of feature rows; the shape is not established here.

    Returns:
      The result of `array_ops.slice`. It keeps every row (size -1 on
      dimension 0) and `self.params.num_leaves` columns, starting at column
      `self.params.num_nodes - self.params.num_leaves - 1`.
    """
    with ops.device(self.device_assigner):
        probabilities, route = gen_training_ops.hard_routing_function(
            data,
            self.tree_parameters,
            self.tree_thresholds,
            max_nodes=self.params.num_nodes,
            tree_depth=self.params.hybrid_tree_depth,
        )

        unpacked = gen_training_ops.unpack_path(route, probabilities)

        # NOTE(review): the window begins one column before
        # num_nodes - num_leaves. Whether this "- 1" matches unpack_path's
        # column layout cannot be confirmed from this code; verify
        # against the op definition.
        first_column = self.params.num_nodes - self.params.num_leaves - 1
        leaf_columns = array_ops.slice(
            unpacked,
            [0, first_column],
            [-1, self.params.num_leaves],
        )
    return leaf_columns
def soft_inference_graph(self, data):
    """Build the stochastic hard-routing inference graph for this tree.

    This works like the hard-routing graph, but routes `data` with
    `gen_training_ops.stochastic_hard_routing_function`, seeded from
    `self.params.base_random_seed`. The op name suggests the routing is
    sampled; confirm this against the op definition. The returned
    (path_probability, path) pair is expanded with
    `gen_training_ops.unpack_path`, and a window of
    `self.params.num_leaves` columns is sliced out. All ops are created
    under `ops.device(self.device_assigner)`.

    Args:
      data: Input handed unchanged to the routing op. Presumably a batch
        of feature rows; the shape is not established here.

    Returns:
      The result of `array_ops.slice`. It keeps every row (size -1 on
      dimension 0) and `self.params.num_leaves` columns, starting at column
      `self.params.num_nodes - self.params.num_leaves - 1`.
    """
    with ops.device(self.device_assigner):
        routing_op = gen_training_ops.stochastic_hard_routing_function
        probabilities, route = routing_op(
            data,
            self.tree_parameters,
            self.tree_thresholds,
            tree_depth=self.params.hybrid_tree_depth,
            random_seed=self.params.base_random_seed,
        )

        unpacked = gen_training_ops.unpack_path(route, probabilities)

        # NOTE(review): same "- 1" window offset as the non-stochastic
        # path. Its correctness depends on unpack_path's column layout,
        # which is not visible here.
        first_column = self.params.num_nodes - self.params.num_leaves - 1
        leaf_columns = array_ops.slice(
            unpacked,
            [0, first_column],
            [-1, self.params.num_leaves],
        )
    return leaf_columns