コード例 #1
0
    def initialize(self):
        """Sort columns by declared type and create per-column bookkeeping.

        Every entry of ``self.input_node.column_types`` is located in
        ``self.input_node.column_names`` and its position is appended to
        ``self.categorical_col`` or ``self.numerical_col``. Afterwards, each
        categorical position gets a label encoder, a value counter and an
        empty frequency dict, plus one counter for every pairing with a
        later categorical position and with every numerical position.

        Raises:
            ValueError: If a column type is neither 'categorical' nor
                'numerical'.
        """
        for name, kind in self.input_node.column_types.items():
            if kind == 'categorical':
                bucket = self.categorical_col
            elif kind == 'numerical':
                bucket = self.numerical_col
            else:
                raise ValueError(
                    'Unsupported column type: {type}'.format(type=kind))
            bucket.append(self.input_node.column_names.index(name))

        # `next_pos` is the list position right after the current categorical
        # column, so the slice below yields only the columns that follow it.
        for next_pos, cat_index in enumerate(self.categorical_col, start=1):
            self.label_encoders[cat_index] = encoder.LabelEncoder()
            self.value_counters[cat_index] = (
                collections.defaultdict(return_zero))
            self.count_frequency[cat_index] = {}

            later_cats = self.categorical_col[next_pos:]
            for other_cat in later_cats:
                pair = (cat_index, other_cat)
                self.categorical_categorical[pair] = (
                    collections.defaultdict(return_zero))

            for num_index in self.numerical_col:
                pair = (num_index, cat_index)
                self.numerical_categorical[pair] = (
                    collections.defaultdict(return_zero))
コード例 #2
0
 def _fit(self, y):
     """Infer ``num_classes`` from the targets and set up loss and encoder.

     Args:
         y: The targets. The code handles a ``tf.data.Dataset``, a
             ``pd.DataFrame``, a ``pd.Series``, or an object exposing
             ``flatten()``, ``shape`` and ``len()`` (presumably a numpy
             array -- confirm against callers).

     Raises:
         ValueError: If `y` is an empty ``tf.data.Dataset`` and
             ``num_classes`` has not been set, so nothing can be inferred.
     """
     super()._fit(y)
     if isinstance(y, tf.data.Dataset):
         if not self.num_classes:
             # Peek at the first element of the dataset *instance*. The
             # previous code iterated over the `tf.data.Dataset` class
             # itself, which is not iterable and raised TypeError.
             first = next(iter(y), None)
             if first is None:
                 raise ValueError(
                     'Cannot infer num_classes from an empty dataset.')
             shape = first.shape[0]
             # A single-entry target is treated as two classes.
             self.num_classes = 2 if shape == 1 else shape
         self.set_loss()
         return
     if isinstance(y, pd.DataFrame):
         y = y.values
     if isinstance(y, pd.Series):
         y = y.values.reshape(-1, 1)
     # Not label: more than one value per row, so the second dimension is
     # taken as the number of classes and no label encoder is fitted.
     if len(y.flatten()) != len(y):
         self.num_classes = y.shape[1]
         self.set_loss()
         return
     labels = set(y.flatten())
     if self.num_classes is None:
         self.num_classes = len(labels)
     if self.num_classes == 2:
         self.label_encoder = encoder.LabelEncoder()
     elif self.num_classes > 2:
         self.label_encoder = encoder.OneHotEncoder()
     # NOTE(review): when num_classes < 2 no encoder is assigned here, so
     # the fit call below relies on `self.label_encoder` being set
     # elsewhere -- confirm.
     self.set_loss()
     self.label_encoder.fit_with_labels(y)
0
 def set_state(self, state):
     """Restore this object's attributes from a `state` mapping.

     Reads the keys 'num_classes', 'multi_label', 'dropout_rate' and
     'label_encoder'. When the stored encoder state is truthy, a fresh
     ``encoder.LabelEncoder`` is created and loaded with it; otherwise
     ``self.label_encoder`` is left as ``None``.
     """
     super().set_state(state)
     for attr in ('num_classes', 'multi_label', 'dropout_rate'):
         setattr(self, attr, state[attr])

     self.label_encoder = None
     encoder_state = state['label_encoder']
     if not encoder_state:
         return
     self.label_encoder = encoder.LabelEncoder()
     self.label_encoder.set_state(encoder_state)
コード例 #4
0
ファイル: head.py プロジェクト: wushicanASL/autokeras
 def set_state(self, state):
     """Restore this object's attributes from a `state` mapping.

     Reads the keys 'num_classes', 'multi_label', 'dropout_rate' and
     'label_encoder'. When the stored encoder state is truthy, the key
     'encoder_class' selects the encoder type: 'one_hot_encoder' gives an
     ``encoder.OneHotEncoder``, anything else an ``encoder.LabelEncoder``.
     The new encoder is then loaded with the stored state. Otherwise
     ``self.label_encoder`` is left as ``None``.
     """
     super().set_state(state)
     for attr in ('num_classes', 'multi_label', 'dropout_rate'):
         setattr(self, attr, state[attr])

     self.label_encoder = None
     encoder_state = state['label_encoder']
     if not encoder_state:
         return

     wants_one_hot = state['encoder_class'] == 'one_hot_encoder'
     encoder_cls = (encoder.OneHotEncoder if wants_one_hot
                    else encoder.LabelEncoder)
     self.label_encoder = encoder_cls()
     self.label_encoder.set_state(encoder_state)
コード例 #5
0
ファイル: preprocessor.py プロジェクト: wushicanASL/autokeras
 def set_weights(self, weights):
     """Restore this object's attributes from a `weights` mapping.

     One ``encoder.LabelEncoder`` is rebuilt per entry of
     ``weights['label_encoders']``. The remaining entries are copied onto
     attributes of the same name, either unchanged or after being passed
     through ``utils.to_type_key`` with ``int`` or ``ast.literal_eval``
     (presumably to restore non-string dict keys -- confirm against
     ``utils.to_type_key``).
     """
     encoder_states = utils.to_type_key(weights['label_encoders'], int)
     for col_index, encoder_state in encoder_states.items():
         restored = encoder.LabelEncoder()
         self.label_encoders[col_index] = restored
         restored.set_state(encoder_state)

     # (attribute / weights key, key converter); None means "copy as is".
     restore_plan = (
         ('shape', None),
         ('num_rows', None),
         ('categorical_col', None),
         ('numerical_col', None),
         ('value_counters', int),
         ('categorical_categorical', ast.literal_eval),
         ('numerical_categorical', ast.literal_eval),
         ('count_frequency', int),
         ('high_level1_col', None),
         ('high_level2_col', None),
         ('high_level_cat_cat', ast.literal_eval),
         ('high_level_num_cat', ast.literal_eval),
     )
     for attr, key_type in restore_plan:
         value = weights[attr]
         if key_type is not None:
             value = utils.to_type_key(value, key_type)
         setattr(self, attr, value)