Пример #1
0
 def finalize(self):
     """Record the output shape and, for multi-class data, decode the targets.

     ``_output_shape`` is taken from the first entry of ``self.targets``.
     With more than two classes, a new ``OneHotEncoder`` is fitted on the
     targets via ``fit_with_one_hot_encoded`` (so the targets are presumably
     one-hot encoded at this point -- TODO confirm with the caller), stored
     in ``_one_hot_encoder``, and used to ``decode`` ``self.targets`` in
     place. The parent ``finalize`` always runs last.
     """
     first_target = self.targets[0]
     self._output_shape = first_target.shape
     if self.num_classes > 2:
         one_hot = self._one_hot_encoder = encoder.OneHotEncoder()
         one_hot.fit_with_one_hot_encoded(self.targets)
         self.targets = one_hot.decode(self.targets)
     super().finalize()
Пример #2
0
 def _fit(self, y):
     """Infer ``num_classes``, configure the loss and fit the label encoder.

     Args:
         y: Training targets: a ``tf.data.Dataset``, a ``pd.DataFrame``, a
             ``pd.Series``, or an array exposing ``flatten()`` and ``shape``
             (presumably ``np.ndarray`` -- confirm against callers).

     Raises:
         ValueError: If ``y`` is an empty ``tf.data.Dataset`` while
             ``num_classes`` is unset, so the class count cannot be inferred.
     """
     super()._fit(y)
     if isinstance(y, tf.data.Dataset):
         # Only the class count is inferred for datasets; no label encoder
         # is fitted on this path.
         if not self.num_classes:
             # Peek at the first element of the *given* dataset. Iterating
             # the ``tf.data.Dataset`` class itself is a TypeError, and an
             # empty dataset must not leave the width undefined.
             first_element = next(iter(y), None)
             if first_element is None:
                 raise ValueError(
                     'Cannot infer num_classes from an empty '
                     'tf.data.Dataset; set num_classes explicitly.')
             width = first_element.shape[0]
             # A single column is treated as binary classification.
             self.num_classes = 2 if width == 1 else width
         self.set_loss()
         return
     if isinstance(y, pd.DataFrame):
         y = y.values
     if isinstance(y, pd.Series):
         y = y.values.reshape(-1, 1)
     # More than one value per sample: the targets are already encoded
     # (presumably one-hot), so the column count is the class count.
     if len(y.flatten()) != len(y):
         self.num_classes = y.shape[1]
         self.set_loss()
         return
     if self.num_classes is None:
         # Only build the label set when the class count must be inferred.
         self.num_classes = len(set(y.flatten()))
     if self.num_classes == 2:
         self.label_encoder = encoder.LabelEncoder()
     elif self.num_classes > 2:
         self.label_encoder = encoder.OneHotEncoder()
     # NOTE(review): when num_classes < 2 no encoder is created here, so the
     # call below relies on ``label_encoder`` having been set elsewhere.
     self.set_loss()
     self.label_encoder.fit_with_labels(y)
Пример #3
0
 def set_state(self, state):
     """Restore configuration and the optional label encoder from ``state``.

     ``num_classes``, ``multi_label`` and ``dropout_rate`` are copied
     straight from ``state``. ``label_encoder`` is reset to ``None`` and is
     rebuilt only when ``state['label_encoder']`` is truthy: a
     ``OneHotEncoder`` when ``state['encoder_class']`` is
     ``'one_hot_encoder'``, otherwise a ``LabelEncoder``. The rebuilt
     encoder then loads the saved encoder state.
     """
     super().set_state(state)
     for attr_name in ('num_classes', 'multi_label', 'dropout_rate'):
         setattr(self, attr_name, state[attr_name])
     self.label_encoder = None
     saved_encoder_state = state['label_encoder']
     if not saved_encoder_state:
         return
     encoder_factory = (encoder.OneHotEncoder
                        if state['encoder_class'] == 'one_hot_encoder'
                        else encoder.LabelEncoder)
     restored = self.label_encoder = encoder_factory()
     restored.set_state(saved_encoder_state)
Пример #4
0
 def clear_weights(self):
     """Drop learned state: new LightGBM classifier and new one-hot encoder.

     The parent ``clear_weights`` runs first. Afterwards ``_one_hot_encoder``
     is replaced by an unfitted ``OneHotEncoder`` and ``lgbm`` by a fresh
     ``LGBMClassifier`` seeded with ``self.seed``.
     """
     super().clear_weights()
     seed = self.seed
     self._one_hot_encoder = encoder.OneHotEncoder()
     self.lgbm = lgb.LGBMClassifier(random_state=seed)
Пример #5
0
 def set_weights(self, weights):
     """Restore trained state from the ``weights`` mapping.

     After the parent restores its part, a new ``OneHotEncoder`` is created,
     stored in ``_one_hot_encoder`` and loaded from
     ``weights['one_hot_encoder']``. ``num_classes`` is then taken from
     ``weights['num_classes']``.
     """
     super().set_weights(weights)
     restored = self._one_hot_encoder = encoder.OneHotEncoder()
     restored.set_state(weights['one_hot_encoder'])
     self.num_classes = weights['num_classes']