Exemplo n.º 1
0
 def __init__(self,
              params,
              tree_configs=None,
              tree_stats=None,
              device_assigner=None,
              variables=None,
              tree_variables_class=TreeVariables,
              tree_graphs=None,
              training=True):
     """Create (or adopt) the forest variables and build one graph per tree.

     Args:
       params: Forest hyperparameters. ``params.num_trees`` determines how
         many tree graphs are created, and ``params.__dict__`` is logged.
       tree_configs: Passed through to ``ForestVariables`` when
         ``variables`` has to be constructed here.
       tree_stats: Passed through to ``ForestVariables`` when ``variables``
         has to be constructed here.
       device_assigner: Device chooser stored on the forest and handed to
         ``ForestVariables``. Any falsy value is replaced with a fresh
         ``framework_variables.VariableDeviceChooser()``.
       variables: Existing per-tree variables, indexed by tree number. Any
         falsy value causes a new ``ForestVariables`` to be built.
       tree_variables_class: Passed through to ``ForestVariables``.
       tree_graphs: Class instantiated once per tree as
         ``tree_graphs(variables[i], params, i)``; a falsy value selects
         ``RandomTreeGraphs``.
       training: Passed through to ``ForestVariables`` only.
     """
     self.params = params

     # Fall back to a default chooser only when the caller supplied none.
     if device_assigner:
         self.device_assigner = device_assigner
     else:
         self.device_assigner = framework_variables.VariableDeviceChooser()

     logging.info('Constructing forest with params = ')
     logging.info(self.params.__dict__)

     if variables:
         self.variables = variables
     else:
         self.variables = ForestVariables(
             self.params,
             device_assigner=self.device_assigner,
             training=training,
             tree_variables_class=tree_variables_class,
             tree_configs=tree_configs,
             tree_stats=tree_stats)

     graph_cls = tree_graphs if tree_graphs else RandomTreeGraphs
     # Collect into a local list so self.trees is bound only after every
     # tree graph has been constructed.
     built_trees = []
     for tree_idx in range(self.params.num_trees):
         built_trees.append(
             graph_cls(self.variables[tree_idx], self.params, tree_idx))
     self.trees = built_trees
Exemplo n.º 2
0
    def __init__(self,
                 params,
                 device_assigner=None,
                 optimizer_class=adagrad.AdagradOptimizer,
                 **kwargs):
        """Configure device placement, the optimizer and the regularizer.

        Args:
          params: Model hyperparameters. This method reads
            ``learning_rate``, ``regression``, ``regularization`` and (only
            for "l1"/"l2") ``regularization_strength``.
          device_assigner: Device chooser; any falsy value is replaced with
            a fresh ``framework_variables.VariableDeviceChooser()``.
          optimizer_class: Callable invoked with ``params.learning_rate``;
            its result is stored as ``self.optimizer``.
          **kwargs: Accepted for signature compatibility; not used here.
        """
        self.device_assigner = (
            device_assigner if device_assigner
            else framework_variables.VariableDeviceChooser())

        self.params = params
        self.optimizer = optimizer_class(self.params.learning_rate)
        self.is_regression = params.regression

        # Any regularization value other than "l1"/"l2" leaves no
        # regularizer configured.
        self.regularizer = None
        reg_kind = params.regularization
        if reg_kind == "l1":
            make_regularizer = layers.l1_regularizer
        elif reg_kind == "l2":
            make_regularizer = layers.l2_regularizer
        else:
            make_regularizer = None

        if make_regularizer is not None:
            self.regularizer = make_regularizer(
                self.params.regularization_strength)
Exemplo n.º 3
0
 def __init__(self, params, layer_num, device_assigner, *args, **kwargs):
   """Record layer settings, then create this layer's variables.

   Args:
     params: Hyperparameter object; stored on the layer and passed to
       ``self._define_vars``.
     layer_num: Index of this layer, stored as ``self.layer_num``.
     device_assigner: Device chooser; any falsy value is replaced with a
       fresh ``framework_variables.VariableDeviceChooser()``.
     *args: Accepted and ignored.
     **kwargs: Forwarded unchanged to ``self._define_vars``.
   """
   self.layer_num = layer_num
   chooser = device_assigner
   if not chooser:
     chooser = framework_variables.VariableDeviceChooser()
   self.device_assigner = chooser
   self.params = params
   self._define_vars(params, **kwargs)