self.training_trainer = SupervisedTrainer
        self.training_epochs = 10
        self.training_batch_size = 32


# In a full project, this model class would live in its own model.py module.
class ImageClassifierSimple(nn.Module):
    """Small convolutional image classifier built from deeptech layers.

    The network is a single ``Sequential`` stack made of:

    * ``ImageConversion(standardize=False, to_channel_first=True)``,
    * four stages of 3x3 ``Conv2D`` -> ReLU -> ``MaxPooling2D`` ->
      ``BatchNormalization``, with 12, 18, 18 and 18 filters respectively,
    * a head of ``Flatten`` -> ``Dense(18)`` -> ReLU -> ``Dense(10)`` ->
      softmax over dim 1.

    Args:
        config: Configuration object. It is stored as ``self.config`` but is
            not read anywhere inside this class.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        stack = [ImageConversion(standardize=False, to_channel_first=True)]

        # Feature extractor: identical conv stages that differ only in width.
        for n_filters in (12, 18, 18, 18):
            stack += [
                Conv2D(kernel_size=(3, 3), filters=n_filters),
                Activation("relu"),
                MaxPooling2D(),
                BatchNormalization(),
            ]

        # Classifier head: 10 outputs normalised with softmax along dim 1.
        stack += [
            Flatten(),
            Dense(18),
            Activation("relu"),
            Dense(10),
            Activation("softmax", dim=1),
        ]

        self.layers = Sequential(*stack)

    def forward(self, image):
        """Pass *image* through the layer stack and return its output.

        NOTE(review): the input layout is presumably channel-last, given that
        ``ImageConversion`` is configured with ``to_channel_first=True``.
        Confirm this against the data pipeline.
        """
        return self.layers(image)


# Run with parameters parsed from the command line, for example:
# python -m deeptech.examples.mnist_custom_model --mode=train --input=Datasets --output=Results
if __name__ == "__main__":
    # FashionMNISTConfig is passed uncalled; cli.run is left to handle it.
    cli.run(FashionMNISTConfig)
        self.training_initial_lr = 0.001

    def create_loss(self, model):
        """Build the combined RPN + final-stage detection loss.

        Two ``DetectionLoss`` instances are created and combined with
        ``MultiLoss`` under the names ``rpn`` and ``final``:

        * ``rpn``: ``rpn_anchors`` / ``rpn_deltas`` / ``rpn_class_ids``
          against ``fg_bg_classes``, with ``lower_tresh=0.3`` and
          ``upper_tresh=0.5``.
        * ``final``: ``final_anchors`` / ``final_deltas`` /
          ``final_class_ids`` against ``class_ids``, with
          ``lower_tresh=0.5`` and ``upper_tresh=0.7``.

        Both compare against ``boxes`` with ``channel_last_gt=True``.

        Args:
            model: Forwarded unchanged as the first argument of ``MultiLoss``.

        Returns:
            The ``MultiLoss`` wrapping both detection losses.
        """
        use_log_deltas = self.model_log_delta_preds

        # Settings shared by both stages. ``delta_preds`` and
        # ``log_delta_preds`` always have opposite truthiness, so exactly one
        # prediction encoding is selected by ``model_log_delta_preds``.
        common = {
            "target_boxes": "boxes",
            "channel_last_gt": True,
            "delta_preds": not use_log_deltas,
            "log_delta_preds": use_log_deltas,
        }

        # NOTE(review): the keyword spelling ``lower_tresh`` / ``upper_tresh``
        # is kept exactly as in the original call. These are presumably
        # IoU thresholds for anchor matching; confirm against DetectionLoss.
        rpn = DetectionLoss(
            anchors="rpn_anchors",
            pred_boxes="rpn_deltas",
            pred_class_ids="rpn_class_ids",
            target_class_ids="fg_bg_classes",
            lower_tresh=0.3,
            upper_tresh=0.5,
            **common,
        )
        final = DetectionLoss(
            anchors="final_anchors",
            pred_boxes="final_deltas",
            pred_class_ids="final_class_ids",
            target_class_ids="class_ids",
            lower_tresh=0.5,
            upper_tresh=0.7,
            **common,
        )
        return MultiLoss(model, rpn=rpn, final=final)


# Run with parameters parsed from the command line, for example:
# python -m deeptech.examples.mnist_custom_loss --mode=train --input=Datasets --output=Results
# NOTE(review): the module path above names the MNIST custom-loss example,
# but this guard runs COCOFasterRCNNConfig. Confirm the intended module path.
if __name__ == "__main__":
    # COCOFasterRCNNConfig is passed uncalled; cli.run is left to handle it.
    cli.run(COCOFasterRCNNConfig)