Example #1
0
    def inference_graph(self, data):
        """Build the hard-routing inference subgraph for this layer.

        All ops are created inside ``ops.device(self.device_assigner)``.
        ``data`` is routed with ``gen_training_ops.hard_routing_function``
        using the layer's ``tree_parameters`` / ``tree_thresholds``. The
        returned (probability, path) pair is expanded with
        ``gen_training_ops.unpack_path``, and a window of
        ``self.params.num_leaves`` columns is sliced out of that expansion.

        Args:
          data: Input tensor to route. Presumably (batch, num_features)
            floats -- TODO confirm against callers.

        Returns:
          The ``array_ops.slice`` result: every row, and ``num_leaves``
          columns starting at column ``num_nodes - num_leaves - 1``.
        """
        with ops.device(self.device_assigner):
            routed_prob, routed_path = gen_training_ops.hard_routing_function(
                data,
                self.tree_parameters,
                self.tree_thresholds,
                max_nodes=self.params.num_nodes,
                tree_depth=self.params.hybrid_tree_depth)

            expanded = gen_training_ops.unpack_path(routed_path, routed_prob)

            # NOTE(review): if the leaves occupy the last ``num_leaves``
            # columns of ``expanded``, the trailing ``- 1`` shifts this window
            # one column left. Kept unchanged; confirm against unpack_path's
            # output layout and how num_nodes/num_leaves are derived.
            first_column = self.params.num_nodes - self.params.num_leaves - 1
            begin = [0, first_column]
            size = [-1, self.params.num_leaves]  # -1 keeps every row.

            return array_ops.slice(expanded, begin, size)
  def inference_graph(self, data):
    """Build the hard-routing inference subgraph for this layer.

    All ops are created inside ``ops.device(self.device_assigner)``. ``data``
    is routed with ``gen_training_ops.hard_routing_function`` using the
    layer's ``tree_parameters`` / ``tree_thresholds``. The returned
    (probability, path) pair is expanded with
    ``gen_training_ops.unpack_path``, and a window of
    ``self.params.num_leaves`` columns is sliced out of that expansion.

    Args:
      data: Input tensor to route. Presumably (batch, num_features) floats
        -- TODO confirm against callers.

    Returns:
      The ``array_ops.slice`` result: every row, and ``num_leaves`` columns
      starting at column ``num_nodes - num_leaves - 1``.
    """
    with ops.device(self.device_assigner):
      routed_prob, routed_path = gen_training_ops.hard_routing_function(
          data,
          self.tree_parameters,
          self.tree_thresholds,
          max_nodes=self.params.num_nodes,
          tree_depth=self.params.hybrid_tree_depth)

      expanded = gen_training_ops.unpack_path(routed_path, routed_prob)

      # NOTE(review): if the leaves occupy the last ``num_leaves`` columns of
      # ``expanded``, the trailing ``- 1`` shifts this window one column left.
      # Kept unchanged; confirm against unpack_path's output layout and how
      # num_nodes/num_leaves are derived.
      first_column = self.params.num_nodes - self.params.num_leaves - 1
      begin = [0, first_column]
      size = [-1, self.params.num_leaves]  # -1 keeps every row.

      return array_ops.slice(expanded, begin, size)
Example #3
0
    def soft_inference_graph(self, data):
        """Build the stochastic-routing inference subgraph for this layer.

        All ops are created inside ``ops.device(self.device_assigner)``.
        ``data`` is routed with
        ``gen_training_ops.stochastic_hard_routing_function``, seeded with
        ``self.params.base_random_seed``. No ``max_nodes`` argument is passed
        to this op. The returned (probability, path) pair is expanded with
        ``gen_training_ops.unpack_path``, and a window of
        ``self.params.num_leaves`` columns is sliced out of that expansion.

        Args:
          data: Input tensor to route. Presumably (batch, num_features)
            floats -- TODO confirm against callers.

        Returns:
          The ``array_ops.slice`` result: every row, and ``num_leaves``
          columns starting at column ``num_nodes - num_leaves - 1``.
        """
        with ops.device(self.device_assigner):
            route_prob, route_path = (
                gen_training_ops.stochastic_hard_routing_function(
                    data,
                    self.tree_parameters,
                    self.tree_thresholds,
                    tree_depth=self.params.hybrid_tree_depth,
                    random_seed=self.params.base_random_seed))

            expanded = gen_training_ops.unpack_path(route_path, route_prob)

            # NOTE(review): if the leaves occupy the last ``num_leaves``
            # columns of ``expanded``, the trailing ``- 1`` shifts this window
            # one column left. Kept unchanged; confirm against unpack_path's
            # output layout and how num_nodes/num_leaves are derived.
            first_column = self.params.num_nodes - self.params.num_leaves - 1
            begin = [0, first_column]
            size = [-1, self.params.num_leaves]  # -1 keeps every row.

            return array_ops.slice(expanded, begin, size)
  def soft_inference_graph(self, data):
    """Build the stochastic-routing inference subgraph for this layer.

    All ops are created inside ``ops.device(self.device_assigner)``. ``data``
    is routed with ``gen_training_ops.stochastic_hard_routing_function``,
    seeded with ``self.params.base_random_seed``. No ``max_nodes`` argument
    is passed to this op. The returned (probability, path) pair is expanded
    with ``gen_training_ops.unpack_path``, and a window of
    ``self.params.num_leaves`` columns is sliced out of that expansion.

    Args:
      data: Input tensor to route. Presumably (batch, num_features) floats
        -- TODO confirm against callers.

    Returns:
      The ``array_ops.slice`` result: every row, and ``num_leaves`` columns
      starting at column ``num_nodes - num_leaves - 1``.
    """
    with ops.device(self.device_assigner):
      route_prob, route_path = (
          gen_training_ops.stochastic_hard_routing_function(
              data,
              self.tree_parameters,
              self.tree_thresholds,
              tree_depth=self.params.hybrid_tree_depth,
              random_seed=self.params.base_random_seed))

      expanded = gen_training_ops.unpack_path(route_path, route_prob)

      # NOTE(review): if the leaves occupy the last ``num_leaves`` columns of
      # ``expanded``, the trailing ``- 1`` shifts this window one column left.
      # Kept unchanged; confirm against unpack_path's output layout and how
      # num_nodes/num_leaves are derived.
      first_column = self.params.num_nodes - self.params.num_leaves - 1
      begin = [0, first_column]
      size = [-1, self.params.num_leaves]  # -1 keeps every row.

      return array_ops.slice(expanded, begin, size)