def initialize(self):
    """Sort the input columns by type and allocate per-column containers.

    Each entry of `self.input_node.column_types` is resolved to its position
    in `self.input_node.column_names` and appended to `self.categorical_col`
    or `self.numerical_col`. Then, for every categorical column, a label
    encoder, a value counter and a frequency dict are created, together
    with one counter per (categorical, later categorical) pair and one per
    (numerical, categorical) pair.

    Raises:
        ValueError: If a column type is neither 'categorical' nor
            'numerical'. Columns processed before the offending one have
            already been recorded at that point.
    """
    for name, kind in self.input_node.column_types.items():
        # Pick the destination list first; the position lookup only happens
        # for supported column types.
        if kind == 'categorical':
            bucket = self.categorical_col
        elif kind == 'numerical':
            bucket = self.numerical_col
        else:
            raise ValueError('Unsupported column type: '
                             '{type}'.format(type=kind))
        bucket.append(self.input_node.column_names.index(name))

    for position, cat_index in enumerate(self.categorical_col):
        self.label_encoders[cat_index] = encoder.LabelEncoder()
        self.value_counters[cat_index] = collections.defaultdict(
            return_zero)
        self.count_frequency[cat_index] = {}

        # Pair this column only with the categorical columns after it, so
        # each unordered pair is registered exactly once.
        for other_cat_index in self.categorical_col[position + 1:]:
            cat_pair = (cat_index, other_cat_index)
            self.categorical_categorical[cat_pair] = (
                collections.defaultdict(return_zero))

        # Keys here are ordered (numerical, categorical).
        for num_index in self.numerical_col:
            num_cat_pair = (num_index, cat_index)
            self.numerical_categorical[num_cat_pair] = (
                collections.defaultdict(return_zero))
def _fit(self, y):
    """Infer `num_classes` and the label encoder from the targets `y`.

    Args:
        y: The training targets. Handled as a `tf.data.Dataset`, a
            `pd.DataFrame`, a `pd.Series`, or otherwise an object that
            offers `flatten()`, `shape` and `len()`.

    Raises:
        ValueError: If `y` is a dataset that yields no elements while
            `num_classes` still has to be inferred from it.
    """
    super()._fit(y)

    if isinstance(y, tf.data.Dataset):
        if not self.num_classes:
            # Look at the first element of the dataset *instance*. The
            # previous code iterated the `tf.data.Dataset` class itself,
            # which is not iterable.
            first_element = next(iter(y), None)
            if first_element is None:
                raise ValueError(
                    'Cannot infer num_classes from an empty dataset.')
            shape = first_element.shape[0]
            # A single output unit is treated as binary classification.
            self.num_classes = 2 if shape == 1 else shape
        self.set_loss()
        return

    if isinstance(y, pd.DataFrame):
        y = y.values
    if isinstance(y, pd.Series):
        y = y.values.reshape(-1, 1)

    # Not label: more than one value per row, so the second dimension is
    # used directly as the number of classes and no encoder is fitted.
    if len(y.flatten()) != len(y):
        self.num_classes = y.shape[1]
        self.set_loss()
        return

    if self.num_classes is None:
        self.num_classes = len(set(y.flatten()))

    if self.num_classes == 2:
        self.label_encoder = encoder.LabelEncoder()
    elif self.num_classes > 2:
        self.label_encoder = encoder.OneHotEncoder()
    # NOTE(review): for num_classes < 2 no encoder is created here, so the
    # call below relies on whatever self.label_encoder already holds.
    self.set_loss()
    self.label_encoder.fit_with_labels(y)
def set_state(self, state):
    """Restore this object's configuration from a state dict.

    Args:
        state: Mapping with the keys 'num_classes', 'multi_label',
            'dropout_rate' and 'label_encoder'. An optional
            'encoder_class' entry equal to 'one_hot_encoder' selects
            `encoder.OneHotEncoder` for the restored encoder; any other
            value, or a missing entry, selects `encoder.LabelEncoder`.
    """
    super().set_state(state)
    self.num_classes = state['num_classes']
    self.multi_label = state['multi_label']
    self.dropout_rate = state['dropout_rate']
    self.label_encoder = None

    encoder_state = state['label_encoder']
    if not encoder_state:
        # Nothing was saved for the encoder; leave it unset.
        return

    # Rebuild the same kind of encoder that was saved rather than always
    # a LabelEncoder. `.get` keeps states without 'encoder_class' working
    # exactly as before.
    if state.get('encoder_class') == 'one_hot_encoder':
        self.label_encoder = encoder.OneHotEncoder()
    else:
        self.label_encoder = encoder.LabelEncoder()
    self.label_encoder.set_state(encoder_state)
def set_state(self, state):
    """Restore this object's configuration from a state dict.

    Args:
        state: Mapping with the keys 'num_classes', 'multi_label',
            'dropout_rate' and 'label_encoder'. When 'label_encoder' is
            truthy, 'encoder_class' is read as well: the value
            'one_hot_encoder' selects `encoder.OneHotEncoder`, anything
            else selects `encoder.LabelEncoder`.
    """
    super().set_state(state)

    # These entries are copied onto the instance unchanged, in this order.
    for attribute in ('num_classes', 'multi_label', 'dropout_rate'):
        setattr(self, attribute, state[attribute])

    self.label_encoder = None
    saved_encoder_state = state['label_encoder']
    if not saved_encoder_state:
        return

    is_one_hot = state['encoder_class'] == 'one_hot_encoder'
    restored = (encoder.OneHotEncoder() if is_one_hot
                else encoder.LabelEncoder())
    # Attach the encoder before loading its state, as the state load may
    # raise and the attribute should already point at the new encoder.
    self.label_encoder = restored
    restored.set_state(saved_encoder_state)
def set_weights(self, weights):
    """Restore the fitted attributes of this object from `weights`.

    Args:
        weights: Mapping holding the entries read below. Dict-valued
            entries are passed through `utils.to_type_key` together with
            either `int` or `ast.literal_eval`; presumably this converts
            their keys back from a serialized form (confirm against
            `utils.to_type_key`).
    """
    encoder_states = utils.to_type_key(weights['label_encoders'], int)
    for column, encoder_state in encoder_states.items():
        restored = encoder.LabelEncoder()
        self.label_encoders[column] = restored
        restored.set_state(encoder_state)

    # (attribute name, second argument for utils.to_type_key). None means
    # the stored value is assigned as-is. The order matches the order in
    # which the entries are read from `weights`.
    restore_plan = (
        ('shape', None),
        ('num_rows', None),
        ('categorical_col', None),
        ('numerical_col', None),
        ('value_counters', int),
        ('categorical_categorical', ast.literal_eval),
        ('numerical_categorical', ast.literal_eval),
        ('count_frequency', int),
        ('high_level1_col', None),
        ('high_level2_col', None),
        ('high_level_cat_cat', ast.literal_eval),
        ('high_level_num_cat', ast.literal_eval),
    )
    for attribute, key_converter in restore_plan:
        value = weights[attribute]
        if key_converter is not None:
            value = utils.to_type_key(value, key_converter)
        setattr(self, attribute, value)