# Imports needed by this excerpt; MCDropout and the loss wrappers come from astroNN
import tensorflow as tf
from packaging import version
from tensorflow.keras import regularizers
from tensorflow.keras.layers import (Activation, Conv1D, Dense, Flatten, Input,
                                     MaxPooling1D, concatenate)
from tensorflow.keras.models import Model
from tensorflow.python.keras.engine import data_adapter

from astroNN.nn.layers import MCDropout
from astroNN.nn.losses import (bayesian_binary_crossentropy_var_wrapper,
                               bayesian_binary_crossentropy_wrapper,
                               bayesian_categorical_crossentropy_var_wrapper,
                               bayesian_categorical_crossentropy_wrapper,
                               mse_lin_wrapper, mse_var_wrapper)


def custom_train_step(self, data):
    """
    Custom training logic for a single optimization step.

    :param data: one batch as yielded by the data pipeline
    :return: dict mapping metric names to their current values
    """
    data = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

    # run the forward pass and build the loss under the tape so gradients can be taken
    with tf.GradientTape() as tape:
        y_pred = self.keras_model(x, training=True)
        # compiled_loss() is called for its side effect of updating the tracked loss
        # metrics; its return value is superseded by the custom loss computed below
        self.keras_model.compiled_loss(
            y, y_pred, sample_weight, regularization_losses=self.keras_model.losses)

        # build the loss pair: each head's loss also takes the other head's prediction
        # (and, for regression, the known label uncertainty) as input
        if self.task == 'regression':
            variance_loss = mse_var_wrapper(y_pred[0], x['labels_err'])
            output_loss = mse_lin_wrapper(y_pred[1], x['labels_err'])
        elif self.task == 'classification':
            output_loss = bayesian_categorical_crossentropy_wrapper(y_pred[1])
            variance_loss = bayesian_categorical_crossentropy_var_wrapper(y_pred[0])
        elif self.task == 'binary_classification':
            output_loss = bayesian_binary_crossentropy_wrapper(y_pred[1])
            variance_loss = bayesian_binary_crossentropy_var_wrapper(y_pred[0])
        else:
            raise RuntimeError(
                'Only "regression", "classification" and "binary_classification" are supported')

        loss = output_loss(y['output'], y_pred[0]) + variance_loss(y['variance_output'], y_pred[1])

    # apply gradients; TF >= 2.4 exposes minimize(..., tape=...), while older versions
    # need the private _minimize helper (tf.python is not reachable from the public tf
    # namespace, so it has to be imported explicitly)
    if version.parse(tf.__version__) >= version.parse("2.4.0"):
        self.keras_model.optimizer.minimize(loss, self.keras_model.trainable_variables, tape=tape)
    else:
        from tensorflow.python.keras.engine.training import _minimize
        _minimize(self.keras_model.distribute_strategy, tape, self.keras_model.optimizer,
                  loss, self.keras_model.trainable_variables)

    self.keras_model.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {m.name: m.result() for m in self.keras_model.metrics}
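
# A minimal sketch of what the regression loss pair above computes, for reference.
# This illustrates the heteroscedastic loss of Kendall & Gal (2017) that wrappers of
# this kind implement, not astroNN's exact code: the variance head predicts the
# log-variance s = log(sigma^2), the known measurement error of the labels is folded
# into it, and the prediction error is weighted by the resulting precision, so the
# network can attenuate the loss on noisy targets at the cost of the 0.5*s penalty.
def _robust_mse_sketch(y_true, y_pred, log_var, labels_err):
    # fold the known label error into the predicted variance (both in log space)
    total_var = tf.math.log(tf.math.exp(log_var) + tf.square(labels_err))
    # precision-weighted squared error plus the log-variance penalty
    return tf.reduce_mean(
        0.5 * tf.square(y_true - y_pred) * tf.math.exp(-total_var) + 0.5 * total_var,
        axis=-1)
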
def model(self):
    input_tensor = Input(shape=self._input_shape, name='input')
    labels_err_tensor = Input(shape=(self._labels_shape,), name='labels_err')

    # two convolutional blocks with Monte Carlo dropout for variational inference
    cnn_layer_1 = Conv1D(kernel_initializer=self.initializer, padding="same",
                         filters=self.num_filters[0], kernel_size=self.filter_len,
                         kernel_regularizer=regularizers.l2(self.l2))(input_tensor)
    activation_1 = Activation(activation=self.activation)(cnn_layer_1)
    dropout_1 = MCDropout(self.dropout_rate, disable=self.disable_dropout)(activation_1)
    cnn_layer_2 = Conv1D(kernel_initializer=self.initializer, padding="same",
                         filters=self.num_filters[1], kernel_size=self.filter_len,
                         kernel_regularizer=regularizers.l2(self.l2))(dropout_1)
    activation_2 = Activation(activation=self.activation)(cnn_layer_2)
    maxpool_1 = MaxPooling1D(pool_size=self.pool_length)(activation_2)
    flattener = Flatten()(maxpool_1)
    dropout_2 = MCDropout(self.dropout_rate, disable=self.disable_dropout)(flattener)

    # two dense blocks; the Dense layers are kept linear because the separate
    # Activation layers supply the non-linearity (passing activation= to Dense as
    # well would apply it twice)
    layer_3 = Dense(units=self.num_hidden[0], kernel_regularizer=regularizers.l2(self.l2),
                    kernel_initializer=self.initializer)(dropout_2)
    activation_3 = Activation(activation=self.activation)(layer_3)
    dropout_3 = MCDropout(self.dropout_rate, disable=self.disable_dropout)(activation_3)
    layer_4 = Dense(units=self.num_hidden[1], kernel_regularizer=regularizers.l2(self.l2),
                    kernel_initializer=self.initializer)(dropout_3)
    activation_4 = Activation(activation=self.activation)(layer_4)

    # two heads: the prediction itself and its predicted variance
    output = Dense(units=self._labels_shape, name='output')(activation_4)
    output_activated = Activation(activation=self._last_layer_activation)(output)
    variance_output = Dense(units=self._labels_shape, activation='linear',
                            name='variance_output')(activation_4)

    model = Model(inputs=[input_tensor, labels_err_tensor], outputs=[output, variance_output])
    # new astroNN high performance dropout variational inference on GPU expects a single
    # output, so the prediction model concatenates the activated output with the variance
    model_prediction = Model(inputs=[input_tensor],
                             outputs=concatenate([output_activated, variance_output]))

    if self.task == 'regression':
        variance_loss = mse_var_wrapper(output, labels_err_tensor)
        output_loss = mse_lin_wrapper(variance_output, labels_err_tensor)
    elif self.task == 'classification':
        output_loss = bayesian_categorical_crossentropy_wrapper(variance_output)
        variance_loss = bayesian_categorical_crossentropy_var_wrapper(output)
    elif self.task == 'binary_classification':
        output_loss = bayesian_binary_crossentropy_wrapper(variance_output)
        variance_loss = bayesian_binary_crossentropy_var_wrapper(output)
    else:
        raise RuntimeError(
            'Only "regression", "classification" and "binary_classification" are supported')

    return model, model_prediction, output_loss, variance_loss
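
# A minimal usage sketch, under stated assumptions: ApogeeBCNN is taken to be the
# astroNN class these methods belong to, the shape assignments below are hypothetical
# stand-ins for what train() normally sets up, and the compile() wiring shown is an
# illustration of how the two returned losses map onto the named outputs, not
# astroNN's verbatim compile logic.
def _usage_sketch():
    from astroNN.models import ApogeeBCNN

    net = ApogeeBCNN()
    net._input_shape = (7514, 1)   # hypothetical APOGEE spectrum shape, for illustration
    net._labels_shape = 22         # hypothetical number of output labels

    keras_model, keras_model_predict, output_loss, variance_loss = net.model()
    # each named output is trained against its matching loss; custom_train_step then
    # optimizes output_loss + variance_loss directly, while keras_model_predict is the
    # single-output model used for fast Monte Carlo dropout inference
    keras_model.compile(optimizer='adam',
                        loss={'output': output_loss, 'variance_output': variance_loss})
    return keras_model, keras_model_predict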